Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Error when size-hint size does not match existing metadata.
  • Loading branch information
MrBago committed Dec 13, 2017
commit 136d8f8f05e6430a65f55fd35d6011643371e9c4
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
dataset.toDF
} else {
val newGroup = if (group.size == localSize) {
// Pass along any existing metadata about vector.
group
} else {
new AttributeGroup(localInputCol, localSize)
val newGroup = group.size match {
case `localSize` => group
case -1 => new AttributeGroup(localInputCol, localSize)
case _ =>
val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " +
s"already set to ${group.size}."
throw new SparkException(msg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually, SparkException is used for exceptions within tasks. I'd use IllegalArgumentException.

}

val newCol: Column = localHandleInvalid match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,47 +39,75 @@ class VectorSizeHintSuite
test("Adding size to column of vectors.") {

val size = 3
val vectorColName = "vector"
val denseVector = Vectors.dense(1, 2, 3)
val sparseVector = Vectors.sparse(size, Array(), Array())

val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
val dataFrame = data.toDF("vector")

val transformer = new VectorSizeHint()
.setInputCol("vector")
.setSize(3)
.setHandleInvalid("error")
val withSize = transformer.transform(dataFrame)
val dataFrame = data.toDF(vectorColName)
assert(
AttributeGroup.fromStructField(withSize.schema("vector")).size == size,
AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
"Transformer did not add expected size data.")

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
val withSize = transformer.transform(dataFrame)
assert(
AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
"Transformer did not add expected size data.")
withSize.collect
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well assert withSize.collect().length === 3 to make sure no Rows were incorrectly filtered

}
}

test("Size hint preserves attributes.") {

case class Foo(x: Double, y: Double, z: Double)
val size = 3
val vectorColName = "vector"
val data = Seq((1, 2, 3), (2, 3, 3))
val boo = data.toDF("x", "y", "z")
val dataFrame = data.toDF("x", "y", "z")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z"))
.setOutputCol("vector")
val dataFrameWithMeatadata = assembler.transform(boo)
val group = AttributeGroup.fromStructField(dataFrameWithMeatadata.schema("vector"))
.setOutputCol(vectorColName)
val dataFrameWithMetadata = assembler.transform(dataFrame)
val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))

for (handleInvalid <- Seq("error", "skip", "optimistic")) {
for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol("vector")
.setSize(3)
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
val withSize = transformer.transform(dataFrameWithMeatadata)
val withSize = transformer.transform(dataFrameWithMetadata)

val newGroup = AttributeGroup.fromStructField(withSize.schema("vector"))
val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
assert(newGroup.size === size, "Transformer did not add expected size data.")
assert(
newGroup.attributes.get.deep === group.attributes.get.deep,
"SizeHintTransformer did not preserve attributes.")
withSize.collect
}
}

test("Size miss-match between current and target size raises an error.") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: mismatch

val size = 4
val vectorColName = "vector"
val data = Seq((1, 2, 3), (2, 3, 3))
val dataFrame = data.toDF("x", "y", "z")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z"))
.setOutputCol(vectorColName)
val dataFrameWithMetadata = assembler.transform(dataFrame)

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
intercept[SparkException](transformer.transform(dataFrameWithMetadata))
}
}

Expand Down