PR feedback.
MrBago committed Dec 13, 2017
commit cafa875d60c487d7df0d935f0a1808b30db3d05d
@@ -30,7 +30,11 @@ import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType

/**
* A feature transformer that adds vector size information to a vector column.
* :: Experimental ::
* A feature transformer that adds size information to the metadata of a vector column.
* VectorAssembler needs size information for its input columns and cannot be used on streaming
* dataframes without this metadata.
*
*/
@Experimental
@Since("2.3.0")
@@ -40,9 +44,18 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
@Since("2.3.0")
def this() = this(Identifiable.randomUID("vectSizeHint"))

/**
* The size of Vectors in `inputCol`.
* @group param
*/
@Since("2.3.0")
Member
Add a docstring and mark with @group param

val size = new IntParam(this, "size", "Size of vectors in column.", {s: Int => s >= 0})
val size: IntParam = new IntParam(
this,
"size",
"Size of vectors in column.",
{s: Int => s >= 0})

/** group getParam */
@Since("2.3.0")
Member
Mark with @group getParam

def getSize: Int = getOrDefault(size)

@@ -54,13 +67,20 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
@Since("2.3.0")
def setInputCol(value: String): this.type = set(inputCol, value)

/**
* Param for how to handle invalid entries. Invalid vectors include nulls and vectors with the
* wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an
* error) and `keep` (do not check the vector size, and keep all rows). `error` by default.
Member
keep -> optimistic

* @group param
*/
@Since("2.3.0")
Member
As long as you're overriding this val, can you please override the docstring and specify the default value here?

override val handleInvalid: Param[String] = new Param[String](
this,
"handleInvalid",
"How to handle invalid vectors in inputCol, (invalid vectors include nulls and vectors with " +
"the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` " +
"(throw an error) and `optimistic` (don't check the vector size).",
"How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with " +
"the wrong size. The options are skip (filter out rows with invalid vectors), error " +
"(throw an error) and keep (do not check the vector size, and keep all rows). `error` by " +
Member
keep -> optimistic

"default.",
ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))

/** @group setParam */
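
The three `handleInvalid` modes described in the docstring above can be modeled in plain Scala, without Spark. This is only a rough sketch of the semantics, not the transformer's implementation: `HandleInvalidSketch`, `applySizeHint`, and the use of `Option` to stand in for null vectors are all hypothetical names chosen for illustration, and the third mode is spelled `optimistic` here (the docstring in this commit calls it `keep`; the code constant is `OPTIMISTIC_INVALID`). The sketch throws `IllegalStateException` as a stand-in; per the test suite in this PR, the real transformer surfaces a `SparkException` when the data is evaluated.

```scala
// Simplified, Spark-free model of the handleInvalid modes (illustrative only).
object HandleInvalidSketch {
  type Vec = Array[Double]

  def applySizeHint(
      rows: Seq[Option[Vec]],
      size: Int,
      handleInvalid: String): Seq[Option[Vec]] = {
    // A row is valid when it is non-null (Some) and has the expected length.
    def isValid(row: Option[Vec]): Boolean = row.exists(_.length == size)

    handleInvalid match {
      // skip: drop rows whose vector is null or has the wrong size.
      case "skip" => rows.filter(isValid)
      // error: fail on the first invalid row, otherwise pass rows through.
      case "error" =>
        rows.find(row => !isValid(row)).foreach { bad =>
          val found = bad.map(_.length.toString).getOrElse("null")
          throw new IllegalStateException(s"Expected vector of size $size, got $found")
        }
        rows
      // optimistic: no checks at all; every row is kept as-is.
      case "optimistic" => rows
      case other =>
        throw new IllegalArgumentException(s"Unsupported handleInvalid value: $other")
    }
  }
}
```

With one valid row, one null, and one short vector as input, `skip` keeps only the valid row, `optimistic` keeps all three, and `error` throws.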
@@ -75,18 +95,10 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
val localHandleInvalid = getHandleInvalid

val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
val newGroup = validateSchemaAndSize(dataset.schema, group)
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
dataset.toDF
dataset.toDF()
} else {
val newGroup = group.size match {
case `localSize` => group
case -1 => new AttributeGroup(localInputCol, localSize)
case _ =>
val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " +
s"already set to ${group.size}."
throw new SparkException(msg)
}

val newCol: Column = localHandleInvalid match {
case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
case VectorSizeHint.ERROR_INVALID =>
@@ -100,7 +112,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
s" got ${vector.size}")
}
vector
}.asNondeterministic
}.asNondeterministic()
checkVectorSizeUDF(col(localInputCol))
case VectorSizeHint.SKIP_INVALID =>
val checkVectorSizeUDF = udf { vector: Vector =>
Member
This case can be converted to use pure SQL operations to avoid SerDe costs.

Member
Something like:

val isSparse = col(localInputCol)(0) === lit(0)
val sparseSize = col(localInputCol)(1)
val denseSize = size(col(localInputCol)(3))
val vecSize = when(isSparse, sparseSize).otherwise(denseSize)
val sizeMatches = vecSize === lit(localSize)
when(col(localInputCol).isNotNull && sizeMatches,
  col(localInputCol),
  lit(null))

That should be 90% correct I think : )

Member
Or not? Maybe the analyzer won't allow you to treat a UDF column as a generic struct. Scratch this comment.

Contributor
But it doesn't work: a UserDefinedType column cannot be used as a StructType. @cloud-fan, is there any way to extract the underlying struct directly from a UDT column (in a pure SQL way)?

Contributor
Internally, a UDT column is stored as UserDefinedType.sqlType, so if your UDT maps to a SQL struct type, it can be used as a struct-type column via pure SQL/DataFrame operations.

Contributor
@cloud-fan I tried, but got this error:
org.apache.spark.sql.AnalysisException: Can't extract value from a#3: need struct type but got vector
Is anything wrong here? Test code:

import spark.implicits._
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.functions.col
val df1 = Seq(Tuple1(Vectors.dense(1.0, 2.0))).toDF("a")
// Fails at analysis time with the AnalysisException quoted above.
df1.select(col("a")(0)).show()

Contributor
I think this is a bug; let me look into it.

@@ -113,7 +125,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
checkVectorSizeUDF(col(localInputCol))
}

val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata))
val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
res.na.drop(Array(localInputCol))
} else {
@@ -122,14 +134,39 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
}
}

@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
/**
* Checks that schema can be updated with new size and returns a new attribute group with
* updated size.
*/
private def validateSchemaAndSize(schema: StructType, group: AttributeGroup): AttributeGroup = {
// This will throw a NoSuchElementException if params are not set.
val localSize = getSize
val localInputCol = getInputCol

val inputColType = schema(getInputCol).dataType
require(
inputColType.isInstanceOf[VectorUDT],
s"Input column, $getInputCol must be of Vector type, got $inputColType"
)
schema
group.size match {
case `localSize` => group
case -1 => new AttributeGroup(localInputCol, localSize)
case _ =>
val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " +
s"already set to ${group.size}."
throw new IllegalArgumentException(msg)
}
}

@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val fieldIndex = schema.fieldIndex(getInputCol)
val fields = schema.fields.clone()
val inputField = fields(fieldIndex)
val group = AttributeGroup.fromStructField(inputField)
val newGroup = validateSchemaAndSize(schema, group)
fields(fieldIndex) = inputField.copy(metadata = newGroup.toMetadata())
StructType(fields)
}

@Since("2.3.0")
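
The size-reconciliation rule inside `validateSchemaAndSize` above can be illustrated on plain integers. This is a minimal sketch under stated assumptions: `SizeMetadataSketch`, `reconcileSize`, and `UnknownSize` are hypothetical names, and the existing metadata is reduced to a single `Int` where `-1` means "no size recorded", mirroring how the diff matches on `group.size`.

```scala
// Spark-free sketch of the size check; names here are illustrative only.
object SizeMetadataSketch {
  // -1 stands for "size unknown", as in the `case -1` branch of the diff.
  val UnknownSize: Int = -1

  def reconcileSize(inputCol: String, currentSize: Int, requestedSize: Int): Int =
    currentSize match {
      // Metadata already records the requested size: nothing to change.
      case `requestedSize` => currentSize
      // No size recorded yet: adopt the requested size.
      case UnknownSize => requestedSize
      // A different size is already recorded: refuse to overwrite it.
      case _ =>
        throw new IllegalArgumentException(
          s"Trying to set size of vectors in `$inputCol` to $requestedSize but size " +
            s"already set to $currentSize.")
    }
}
```

Because the check depends only on the schema and the params, it can run in both `transform` and `transformSchema`, which is presumably why this commit factors it into a shared helper and switches the exception to `IllegalArgumentException`.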
@@ -18,11 +18,11 @@
package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest

@@ -36,6 +36,18 @@ class VectorSizeHintSuite
intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
}

test("Required params must be set before transform.") {
val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")

val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
intercept[NoSuchElementException] (noSizeTransformer.transform(data))
intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))

val noInputColTransformer = new VectorSizeHint().setSize(2)
intercept[NoSuchElementException] (noInputColTransformer.transform(data))
intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
}

test("Adding size to column of vectors.") {

val size = 3
@@ -47,7 +59,7 @@ class VectorSizeHintSuite
val dataFrame = data.toDF(vectorColName)
assert(
AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
"Transformer did not add expected size data.")
s"This test requires that column '$vectorColName' not have size metadata.")

for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
val transformer = new VectorSizeHint()
@@ -58,7 +70,8 @@ class VectorSizeHintSuite
assert(
AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
"Transformer did not add expected size data.")
withSize.collect
val numRows = withSize.collect().length
assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
}
}

@@ -83,15 +96,15 @@ class VectorSizeHintSuite
val withSize = transformer.transform(dataFrameWithMetadata)

val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
assert(newGroup.size === size, "Transformer did not add expected size data.")
assert(newGroup.size === size, "Column has incorrect size metadata.")
assert(
newGroup.attributes.get.deep === group.attributes.get.deep,
"SizeHintTransformer did not preserve attributes.")
newGroup.attributes.get === group.attributes.get,
"VectorSizeHint did not preserve attributes.")
withSize.collect
}
}

test("Size miss-match between current and target size raises an error.") {
test("Size mismatch between current and target size raises an error.") {
val size = 4
val vectorColName = "vector"
val data = Seq((1, 2, 3), (2, 3, 3))
@@ -107,7 +120,7 @@ class VectorSizeHintSuite
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
intercept[SparkException](transformer.transform(dataFrameWithMetadata))
intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
}
}

@@ -123,8 +136,8 @@ class VectorSizeHintSuite
.setHandleInvalid("error")
.setSize(3)

intercept[SparkException](sizeHint.transform(dataWithNull).collect)
intercept[SparkException](sizeHint.transform(dataWithShort).collect)
intercept[SparkException](sizeHint.transform(dataWithNull).collect())
intercept[SparkException](sizeHint.transform(dataWithShort).collect())

sizeHint.setHandleInvalid("skip")
assert(sizeHint.transform(dataWithNull).count() === 1)
@@ -144,7 +157,7 @@ class VectorSizeHintStreamingSuite extends StreamTest {

import testImplicits._

test("Test assemble vectors with size hint in steaming.") {
test("Test assemble vectors with size hint in streaming.") {
val a = Vectors.dense(0, 1, 2)
val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))

@@ -159,9 +172,13 @@ class VectorSizeHintStreamingSuite extends StreamTest {
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("a", "b"))
.setOutputCol("assembled")
val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, vectorAssembler))
/**
Member
remove unused code?

val output = Seq(sizeHintA, sizeHintB, vectorAssembler).foldLeft(streamingDF) {
Member
You can just put these in a PipelineModel to avoid using foldLeft.

case (data, transformer) => transformer.transform(data)
}.select("assembled")
*/
val output = pipeline.fit(streamingDF).transform(streamingDF).select("assembled")

val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)
