Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added VectorSizeHint Transformer in ml.feature.
  • Loading branch information
MrBago committed Dec 13, 2017
commit 24cc41792770c7f08481a2bbcb120a119631e5ee
151 changes: 151 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType

/**
* A feature transformer that adds vector size information to a vector column.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add :: Experimental :: note here so it shows up properly in docs. Look at other uses of Experimental for examples. (Same for the companion object)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it'd be good to add more docs about why/when people should use this.

*/
@Experimental
@Since("2.3.0")
class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
extends Transformer with HasInputCol with HasHandleInvalid with DefaultParamsWritable {

@Since("2.3.0")
def this() = this(Identifiable.randomUID("vectSizeHint"))

@Since("2.3.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a docstring and mark with @group param

val size = new Param[Int](this, "size", "Size of vectors in column.", {s: Int => s >= 0})

@Since("2.3.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mark with @group getParam

def getSize: Int = getOrDefault(size)

/** @group setParam */
@Since("2.3.0")
def setSize(value: Int): this.type = set(size, value)

/** @group setParam */
@Since("2.3.0")
def setInputCol(value: String): this.type = set(inputCol, value)

@Since("2.3.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as you're overriding this val, can you please override the docstring and specify the default value here?

override val handleInvalid: Param[String] = new Param[String](
this,
"handleInvalid",
"How to handle invalid vectors in inputCol, (invalid vectors include nulls and vectors with " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The writing here is formatted strangely. How about:
"How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with the wrong size. The options are skip (filter out rows with invalid vectors), error (throw an error) and keep (do not check the vector size, and keep all rows)."

"the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` " +
"(throw an error) and `optimistic` (don't check the vector size).",
ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))

/** @group setParam */
@Since("2.3.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID)

@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val localInputCol = getInputCol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's lightweight, let's call transformSchema here to validate the params. That way, if users have not yet specified a required Param, we can throw an exception with a better error message.

val localSize = getSize
val localHandleInvalid = getHandleInvalid

val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
dataset.toDF
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note IntelliJ style warning: Call toDF() with parentheses.

} else {
val newGroup = if (group.size == localSize) {
// Pass along any existing metadata about vector.
group
} else {
new AttributeGroup(localInputCol, localSize)
}

val newCol: Column = localHandleInvalid match {
case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
case VectorSizeHint.ERROR_INVALID =>
val checkVectorSize = { vector: Vector =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here can simply use:

val checkVectorSizeUDF = udf { vector: Vector => ...}
checkVectorSizeUDF(col(localInputCol))

So code will be clearer.

if (vector == null) {
throw new VectorSizeHint.InvalidEntryException(s"Got null vector in VectorSizeHint," +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The UDF which is possible to throw exception should be marked as nondeterministic, check this PR #19662 for more explanation.

s" set `handleInvalid` to 'skip' to filter invalid rows.")
}
if (vector.size != localSize) {
throw new VectorSizeHint.InvalidEntryException(s"VectorSizeHint Expecting a vector " +
s"of size $localSize but got ${vector.size}")
}
vector
}
udf(checkVectorSize, new VectorUDT)(col(localInputCol))
case VectorSizeHint.SKIP_INVALID =>
val checkVectorSize = { vector: Vector =>
if (vector != null && vector.size == localSize) {
vector
} else {
null
}
}
udf(checkVectorSize, new VectorUDT)(col(localInputCol))
}

val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata))
if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
res.filter(col(localInputCol).isNotNull)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here use res.na.drop(Array(localInputCol)) will be better.

} else {
res
}
}
}

@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val inputColType = schema(getInputCol).dataType
require(
inputColType.isInstanceOf[VectorUDT],
s"Input column, $getInputCol must be of Vector type, got $inputColType"
)
schema
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since transformSchema does final Param validation checks, let's require that 'size' is specified here. Also, add size to the metadata.

}

@Since("2.3.0")
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}

@Experimental
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Scala docstring here with :: Experimental :: note.

@Since("2.3.0")
object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint] {

private[feature] val OPTIMISTIC_INVALID = "optimistic"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call this "keep" instead of "optimistic" in order to match handeInvalid Params in other Transformers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to changing this to keep, but I wanted to lay out the argument for having it not be keep. The way that keep is used in other transformers is generally to map unknown or new values to some valid representation. In other words in an instruction to the transformer to make a "best effort" to deal with invalid values. The important thing here is that not only does keep not error on invalid values, but existing keep implementations actually ensure that invalid values produce a valid result.

The behaviour of this transformer is subtly different. It doesn't do anything to "correct" invalid vectors and as a result using the keep/optimistic option can lead to an invalid state of the DataFrame, namely the metadata is wrong about the contents of the column. Users who are accustomed to using keep on other transformers maybe confused or frustrated by this difference.

Copy link
Member

@jkbradley jkbradley Dec 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK that's a great argument. Let's keep it "optimistic."

Since this is a little confusing, and since users could trip themselves up by using "optimistic," would you mind putting some more warning about use of "optimistic" in the Param docstring? Telling users when to use or not use it would be helpful and might prevent some mistakes.

private[feature] val ERROR_INVALID = "error"
private[feature] val SKIP_INVALID = "skip"
private[feature] val supportedHandleInvalids: Array[String] =
Array(OPTIMISTIC_INVALID, ERROR_INVALID, SKIP_INVALID)

@Since("2.3.0")
class InvalidEntryException(msg: String) extends Exception(msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need define a new exception class ? Or directly use SparkException ?


@Since("2.3.0")
override def load(path: String): VectorSizeHint = super.load(path)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest

class VectorSizeHintSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

import testImplicits._

test("Test Param Validators") {
assertThrows[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
assertThrows[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
}

test("Adding size to column of vectors.") {

val size = 3
val denseVector = Vectors.dense(1, 2, 3)
val sparseVector = Vectors.sparse(size, Array(), Array())

val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
val dataFrame = data.toDF("vector")

val transformer = new VectorSizeHint()
.setInputCol("vector")
.setSize(3)
.setHandleInvalid("error")
val withSize = transformer.transform(dataFrame)
assert(
AttributeGroup.fromStructField(withSize.schema("vector")).size == size,
"Transformer did not add expected size data.")
}

test("Size hint preserves attributes.") {

case class Foo(x: Double, y: Double, z: Double)
val size = 3
val data = Seq((1, 2, 3), (2, 3, 3))
val boo = data.toDF("x", "y", "z")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z"))
.setOutputCol("vector")
val dataFrameWithMeatadata = assembler.transform(boo)
val group = AttributeGroup.fromStructField(dataFrameWithMeatadata.schema("vector"))

val transformer = new VectorSizeHint()
.setInputCol("vector")
.setSize(3)
.setHandleInvalid("error")
val withSize = transformer.transform(dataFrameWithMeatadata)

val newGroup = AttributeGroup.fromStructField(withSize.schema("vector"))
assert(newGroup.size == size, "Transformer did not add expected size data.")
assert(
newGroup.attributes.get.deep === group.attributes.get.deep,
"SizeHintTransformer did not preserve attributes.")
}

test("Handle invalid does the right thing.") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't find a test for optimistic option. We should test:
If input dataset vector column do not include metadata, the VectorSizeHint should add metadata with proper size, or input vector column include metadata with different size, the VectorSizeHint should replace it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked offline to @jkbradley and I think it's better to throw an exception unless if the column includes metadata & the there is a mismatch between the new and original size.

I've added a new test for this exception and made sure the other tests are run with all handleInvalid cases. Does it look ok now?


val vector = Vectors.dense(1, 2, 3)
val short = Vectors.dense(2)
val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")

val sizeHint = new VectorSizeHint()
.setInputCol("vector")
.setHandleInvalid("error")
.setSize(3)

assertThrows[SparkException](sizeHint.transform(dataWithNull).collect)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use intercept[SparkException] {...} is better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made the change. Just out of curiosity, why is intercept better than assertThrows?

assertThrows[SparkException](sizeHint.transform(dataWithShort).collect)

sizeHint.setHandleInvalid("skip")
assert(sizeHint.transform(dataWithNull).count() === 1)
assert(sizeHint.transform(dataWithShort).count() === 1)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test keep/optimistic too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you a thought on how to test keep/optimistic. I could verify that the invalid data is not removed but that's a little bit weird to test. It's ensuring that this option allows the column to get into a "bad state" where the metadata doesn't match the contents. Is that what you had in mind?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's what I had in mind. That is the expected behavior, so we can test that behavior...even if it's not what most use cases would need.

}

class VectorSizeHintStreamingSuite extends StreamTest {

import testImplicits._

test("Test assemble vectors with size hint in steaming.") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

steaming streaming

val a = Vectors.dense(0, 1, 2)
val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))

val stream = MemoryStream[(Vector, Vector)]
val streamingDF = stream.toDS.toDF("a", "b")
val sizeHintA = new VectorSizeHint()
.setSize(3)
.setInputCol("a")
val sizeHintB = new VectorSizeHint()
.setSize(4)
.setInputCol("b")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("a", "b"))
.setOutputCol("assembled")
val output = Seq(sizeHintA, sizeHintB, vectorAssembler).foldLeft(streamingDF) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just put these in a PipelineModel to avoid using foldLeft.

case (data, transform) => transform.transform(data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case (data, transform) ==> case (data, transformer)

}.select("assembled")

val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)

testStream (output) (
AddData(stream, (a, b), (a, b)),
CheckAnswerRows(Seq(Row(expected), Row(expected)), false, false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use CheckAnswer(expected, expected) will be simpler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I didn't use CheckAnswer is because there isn't an implicit encoder in testImplicits that handles Vector. I tried CheckAnswer[Vector](expected, expected) but that doesn't work either :(. Is there an encoder that works for Vectors?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, sorry, it should be CheckAnswer(Tuple1(expected), Tuple1(expected)). It should work I think.

)
}
}