mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala (new file, 188 additions, 0 deletions)
@@ -0,0 +1,188 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType

/**
 * :: Experimental ::
 * A feature transformer that adds size information to the metadata of a vector column.
 * VectorAssembler needs size information for its input columns and cannot be used on streaming
 * dataframes without this metadata.
 */
@Experimental
@Since("2.3.0")
class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
  extends Transformer with HasInputCol with HasHandleInvalid with DefaultParamsWritable {

@Since("2.3.0")
def this() = this(Identifiable.randomUID("vectSizeHint"))

/**
* The size of Vectors in `inputCol`.
* @group param
*/
@Since("2.3.0")

[Member] Add a docstring and mark with @group param

  val size: IntParam = new IntParam(
    this,
    "size",
    "Size of vectors in column.",
    {s: Int => s >= 0})

  /** @group getParam */
  @Since("2.3.0")

[Member] Mark with @group getParam

  def getSize: Int = getOrDefault(size)

  /** @group setParam */
  @Since("2.3.0")
  def setSize(value: Int): this.type = set(size, value)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /**
   * Param for how to handle invalid entries. Invalid vectors include nulls and vectors with the
   * wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an
   * error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.
   * @group param
   */

[Member] keep -> optimistic

[Member] As long as you're overriding this val, can you please override the docstring and specify the default value here?

  @Since("2.3.0")
  override val handleInvalid: Param[String] = new Param[String](
    this,
    "handleInvalid",
    "How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with " +
      "the wrong size. The options are skip (filter out rows with invalid vectors), error " +
      "(throw an error) and optimistic (do not check the vector size, and keep all rows). " +
      "`error` by default.",
    ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))

[Member] keep -> optimistic

  /** @group setParam */
  @Since("2.3.0")
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
  setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID)

@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val localInputCol = getInputCol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's lightweight, let's call transformSchema here to validate the params. That way, if users have not yet specified a required Param, we can throw an exception with a better error message.

    val localSize = getSize
    val localHandleInvalid = getHandleInvalid

    val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
    val newGroup = validateSchemaAndSize(dataset.schema, group)
    if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
      dataset.toDF()
    } else {
      val newCol: Column = localHandleInvalid match {
        case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
        case VectorSizeHint.ERROR_INVALID =>
          val checkVectorSizeUDF = udf { vector: Vector =>
            if (vector == null) {
              throw new SparkException(s"Got null vector in VectorSizeHint, set `handleInvalid` " +
                s"to 'skip' to filter invalid rows.")
            }
            if (vector.size != localSize) {
              throw new SparkException(s"VectorSizeHint: expecting a vector of size $localSize " +
                s"but got ${vector.size}.")
            }
            vector
          }.asNondeterministic()
          checkVectorSizeUDF(col(localInputCol))
        case VectorSizeHint.SKIP_INVALID =>
          val checkVectorSizeUDF = udf { vector: Vector =>

[Member] This case can be converted to use pure SQL operations to avoid SerDe costs.

[Member] Something like:

val isSparse = col(localInputCol)(0) === lit(0)
val sparseSize = col(localInputCol)(1)
val denseSize = size(col(localInputCol)(3))
val vecSize = when(isSparse, sparseSize).otherwise(denseSize)
val sizeMatches = vecSize === lit(localSize)
when(col(localInputCol).isNotNull && sizeMatches,
  col(localInputCol),
  lit(null))

That should be 90% correct I think : )

[Member] Or not? Maybe the analyzer won't allow you to treat a UDF column as a generic struct. Scratch this comment.

[Contributor] But it doesn't work... a UserDefinedType column cannot be used as a StructType. @cloud-fan Is there any way we can directly extract a "struct" from a UDT column (in a pure SQL way)?

[Contributor] Internally, a UDT column is stored as UserDefinedType.sqlType, so if your UDT is mapped to a SQL struct type, we can use it as a struct type column via pure SQL/DataFrame operations.

[Contributor] @cloud-fan I tried, but got this error:
org.apache.spark.sql.AnalysisException: Can't extract value from a#3: need struct type but got vector
Anything wrong? Test code:

import spark.implicits._
import org.apache.spark.ml.linalg._
val df1 = Seq(Tuple1(Vectors.dense(1.0, 2.0))).toDF("a")
df1.select(col("a")(0)).show

[Contributor] I feel this is a bug, let me look into it.

            if (vector != null && vector.size == localSize) {
              vector
            } else {
              null
            }
          }
          checkVectorSizeUDF(col(localInputCol))
      }

      val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
      if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
        res.na.drop(Array(localInputCol))
      } else {
        res
      }
    }
  }

  /**
   * Checks that the schema can be updated with the new size and returns a new attribute group
   * with the updated size.
   */
  private def validateSchemaAndSize(schema: StructType, group: AttributeGroup): AttributeGroup = {
    // This will throw a NoSuchElementException if the params are not set.
    val localSize = getSize
    val localInputCol = getInputCol

    val inputColType = schema(getInputCol).dataType
    require(
      inputColType.isInstanceOf[VectorUDT],
      s"Input column, $getInputCol must be of Vector type, got $inputColType")
    group.size match {
      case `localSize` => group
      case -1 => new AttributeGroup(localInputCol, localSize)
      case _ =>
        val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " +
          s"already set to ${group.size}."
        throw new IllegalArgumentException(msg)
    }
  }

@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val fieldIndex = schema.fieldIndex(getInputCol)
val fields = schema.fields.clone()
val inputField = fields(fieldIndex)
val group = AttributeGroup.fromStructField(inputField)
val newGroup = validateSchemaAndSize(schema, group)
fields(fieldIndex) = inputField.copy(metadata = newGroup.toMetadata())
StructType(fields)
}

@Since("2.3.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
}

@Experimental

[Member] Add a Scala docstring here with a :: Experimental :: note.

@Since("2.3.0")
object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint] {

  private[feature] val OPTIMISTIC_INVALID = "optimistic"

[Member] Can we call this "keep" instead of "optimistic" in order to match handleInvalid Params in other Transformers?

[Contributor Author] I'm open to changing this to keep, but I wanted to lay out the argument for having it not be keep. The way keep is used in other transformers is generally to map unknown or new values to some valid representation; in other words, it is an instruction to the transformer to make a "best effort" to deal with invalid values. The important thing here is that not only does keep not error on invalid values, but existing keep implementations actually ensure that invalid values produce a valid result.

The behaviour of this transformer is subtly different. It doesn't do anything to "correct" invalid vectors, and as a result using the keep/optimistic option can lead to an invalid state of the DataFrame, namely that the metadata is wrong about the contents of the column. Users who are accustomed to using keep on other transformers may be confused or frustrated by this difference.

[jkbradley (Member), Dec 13, 2017] OK, that's a great argument. Let's keep it "optimistic."

Since this is a little confusing, and since users could trip themselves up by using "optimistic," would you mind putting some more warning about the use of "optimistic" in the Param docstring? Telling users when to use or not use it would be helpful and might prevent some mistakes.

  private[feature] val ERROR_INVALID = "error"
  private[feature] val SKIP_INVALID = "skip"
  private[feature] val supportedHandleInvalids: Array[String] =
    Array(OPTIMISTIC_INVALID, ERROR_INVALID, SKIP_INVALID)

@Since("2.3.0")
override def load(path: String): VectorSizeHint = super.load(path)
}
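
To make the new transformer concrete, here is a hedged usage sketch (not part of this diff). It assumes a SparkSession with spark.implicits._ in scope; the column names and values are illustrative only:

import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}
import org.apache.spark.ml.linalg.Vectors

val df = Seq(
  (Vectors.dense(1.0, 2.0, 3.0), 1.0),
  (Vectors.dense(4.0, 5.0, 6.0), 0.0)
).toDF("features", "label")

// Declare the size of "features" so downstream stages (e.g. VectorAssembler on
// a streaming DataFrame) can rely on the metadata instead of inspecting rows.
val sizeHint = new VectorSizeHint()
  .setInputCol("features")
  .setSize(3)
  .setHandleInvalid("error") // "skip" drops invalid rows; "optimistic" skips the check

val hinted = sizeHint.transform(df)

val assembler = new VectorAssembler()
  .setInputCols(Array("features", "label"))
  .setOutputCol("assembled")
assembler.transform(hinted).show()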
mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala (new file, 190 additions, 0 deletions)
@@ -0,0 +1,190 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest

class VectorSizeHintSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("Test Param Validators") {
    intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
    intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
  }

  test("Required params must be set before transform.") {
    val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")

    val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
    intercept[NoSuchElementException] (noSizeTransformer.transform(data))
    intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))

    val noInputColTransformer = new VectorSizeHint().setSize(2)
    intercept[NoSuchElementException] (noInputColTransformer.transform(data))
    intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
  }

test("Adding size to column of vectors.") {

val size = 3
val vectorColName = "vector"
val denseVector = Vectors.dense(1, 2, 3)
val sparseVector = Vectors.sparse(size, Array(), Array())

val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
val dataFrame = data.toDF(vectorColName)
assert(
AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
s"This test requires that column '$vectorColName' not have size metadata.")

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
val withSize = transformer.transform(dataFrame)
assert(
AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
"Transformer did not add expected size data.")
val numRows = withSize.collect().length
assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
}
}

test("Size hint preserves attributes.") {

val size = 3
val vectorColName = "vector"
val data = Seq((1, 2, 3), (2, 3, 3))
val dataFrame = data.toDF("x", "y", "z")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z"))
.setOutputCol(vectorColName)
val dataFrameWithMetadata = assembler.transform(dataFrame)
val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
val withSize = transformer.transform(dataFrameWithMetadata)

val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
assert(newGroup.size === size, "Column has incorrect size metadata.")
assert(
newGroup.attributes.get === group.attributes.get,
"VectorSizeHint did not preserve attributes.")
withSize.collect
}
}

test("Size mismatch between current and target size raises an error.") {
val size = 4
val vectorColName = "vector"
val data = Seq((1, 2, 3), (2, 3, 3))
val dataFrame = data.toDF("x", "y", "z")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z"))
.setOutputCol(vectorColName)
val dataFrameWithMetadata = assembler.transform(dataFrame)

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
}
}

test("Handle invalid does the right thing.") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't find a test for optimistic option. We should test:
If input dataset vector column do not include metadata, the VectorSizeHint should add metadata with proper size, or input vector column include metadata with different size, the VectorSizeHint should replace it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked offline to @jkbradley and I think it's better to throw an exception unless if the column includes metadata & the there is a mismatch between the new and original size.

I've added a new test for this exception and made sure the other tests are run with all handleInvalid cases. Does it look ok now?


val vector = Vectors.dense(1, 2, 3)
val short = Vectors.dense(2)
val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")

    val sizeHint = new VectorSizeHint()
      .setInputCol("vector")
      .setHandleInvalid("error")
      .setSize(3)

    intercept[SparkException](sizeHint.transform(dataWithNull).collect())
    intercept[SparkException](sizeHint.transform(dataWithShort).collect())

    sizeHint.setHandleInvalid("skip")
    assert(sizeHint.transform(dataWithNull).count() === 1)
    assert(sizeHint.transform(dataWithShort).count() === 1)
  }

[Member] Test keep/optimistic too.

[Contributor Author] Did you have a thought on how to test keep/optimistic? I could verify that the invalid data is not removed, but that's a little bit weird to test: it's ensuring that this option allows the column to get into a "bad state" where the metadata doesn't match the contents. Is that what you had in mind?

[Member] Yep, that's what I had in mind. That is the expected behavior, so we can test that behavior... even if it's not what most use cases would need.
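
A sketch of what the requested optimistic test could look like (illustrative only, not part of this diff; the test name and assertions are assumptions):

  test("Optimistic keeps invalid rows and still sets size metadata.") {
    // Hypothetical sketch: one valid vector, one wrong-sized vector, one null.
    val data = Seq(Vectors.dense(1, 2, 3), Vectors.dense(2), null)
      .map(Tuple1.apply).toDF("vector")
    val sizeHint = new VectorSizeHint()
      .setInputCol("vector")
      .setHandleInvalid("optimistic")
      .setSize(3)
    val result = sizeHint.transform(data)
    // The metadata claims size 3 even though the column contains invalid entries.
    assert(AttributeGroup.fromStructField(result.schema("vector")).size === 3)
    // No rows are filtered and no exception is thrown.
    assert(result.count() === 3)
  }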


test("read/write") {
val sizeHint = new VectorSizeHint()
.setInputCol("myInputCol")
.setSize(11)
.setHandleInvalid("skip")
testDefaultReadWrite(sizeHint)
}
}

class VectorSizeHintStreamingSuite extends StreamTest {

  import testImplicits._

  test("Test assemble vectors with size hint in streaming.") {
    val a = Vectors.dense(0, 1, 2)
    val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))

    val stream = MemoryStream[(Vector, Vector)]
    val streamingDF = stream.toDS.toDF("a", "b")
    val sizeHintA = new VectorSizeHint()
      .setSize(3)
      .setInputCol("a")
    val sizeHintB = new VectorSizeHint()
      .setSize(4)
      .setInputCol("b")
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("a", "b"))
      .setOutputCol("assembled")
    val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, vectorAssembler))
[Member] Remove unused code?

[Member] You can just put these in a PipelineModel to avoid using foldLeft.

    /**
    val output = Seq(sizeHintA, sizeHintB, vectorAssembler).foldLeft(streamingDF) {
      case (data, transformer) => transformer.transform(data)
    }.select("assembled")
    */
    val output = pipeline.fit(streamingDF).transform(streamingDF).select("assembled")

    val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)

    testStream(output)(
      AddData(stream, (a, b), (a, b)),
      CheckAnswer(Tuple1(expected), Tuple1(expected))
    )
  }
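
For context, a hedged companion sketch (not part of this diff) of the failure mode the VectorSizeHint docstring describes: without size hints, the columns carry no size metadata, so VectorAssembler would have to inspect rows to learn the vector sizes, which streaming Datasets do not allow. The test name and the exact exception type are assumptions:

  test("VectorAssembler without size hints fails on streaming data (sketch).") {
    // Hypothetical sketch: no VectorSizeHint stages, so "a" and "b" carry no
    // size metadata. Determining sizes would require eagerly reading a row,
    // which is unsupported on streaming Datasets, so transform should throw.
    val stream = MemoryStream[(Vector, Vector)]
    val streamingDF = stream.toDS.toDF("a", "b")
    val assembler = new VectorAssembler()
      .setInputCols(Array("a", "b"))
      .setOutputCol("assembled")
    intercept[Exception](assembler.transform(streamingDF))
  }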
}