From afecd4e7eae8b129eea6ece7ff443e208d2cd2cf Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 11:47:43 -0700 Subject: [PATCH 01/12] Add a param to skip invalid entries. --- .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b0a6af171c01f..8571d4c1979aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,6 +53,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Boolean]("skipInvalid", "whether to skip invalid entries"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " prior to fitting the model sequence. Note that the coefficients of models are" + " always returned on the original scale.", Some("true")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index bbe08939b6d75..929b4f1b6122d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -233,6 +233,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * (private[ml]) Trait for shared param skipInvalid. + */ +private[ml] trait HasSkipInvalid extends Params { + + /** + * Param for whether to skip invalid entries. + * @group param + */ + final val skipInvalid: BooleanParam = new BooleanParam(this, "skipInvalid", "whether to skip invalid entries") + + /** @group getParam */ + final def getSkipInvalid: Boolean = $(skipInvalid) +} + /** * (private[ml]) Trait for shared param standardization (default: true). */ From b5734befde20e5a6b33c30318b423faf49fbaab9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 13:29:09 -0700 Subject: [PATCH 02/12] Add support for unseen labels --- .../spark/ml/feature/StringIndexer.scala | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index bf7be363b8224..4389b61a8b4e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -32,7 +32,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasSkipInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -64,13 +65,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value) + setDefault(skipInvalid, false) + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -110,6 +114,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value) + setDefault(skipInvalid, false) + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,14 +135,27 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels - throw new SparkException(s"Unseen label: $label.") + if (!getSkipInvalid) { + throw new SparkException(s"Unseen label: $label.") + } else { + throw new SparkException(s"Unseen label even when pre-filtering: $label.") + } } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = if (getSkipInvalid) { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } else { + dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } From d69ef5e183ec6a709c903316ca7e4d9c11b80726 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 7 Jul 2015 13:48:12 -0700 Subject: [PATCH 03/12] Add a test --- .../spark/ml/feature/StringIndexerSuite.scala | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 99f82bea42688..e7e4110ee2024 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -49,6 +50,38 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(output === expected) } + test("StringIndexerUnessen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setSkipInvalid(true) + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + + test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From 75ffa6986dc18bae6ee3a749e7b5e33b413cde2f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 8 Jul 2015 15:12:57 -0700 Subject: [PATCH 04/12] Remove extra newline --- .../scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index e7e4110ee2024..eb059691a5fa0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -81,7 +81,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(output === expected) } - test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From aa5b093d3d36d539cff454b149b148d857f816e8 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 13:16:01 -0700 Subject: [PATCH 05/12] Since we filter we should never go down this code path if getSkipInvalid is true --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 4389b61a8b4e9..ca33d61c625c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -135,11 +135,7 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - if (!getSkipInvalid) { - throw new SparkException(s"Unseen label: $label.") - } else { - throw new SparkException(s"Unseen label even when pre-filtering: $label.") - } + throw new SparkException(s"Unseen label: $label.") } } From 7a22215c0534dd1a2978a8ede5d94c34d41c365c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 5 Aug 2015 17:16:45 -0700 Subject: [PATCH 06/12] fix typo --- .../scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index eb059691a5fa0..a443c4b3851eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -50,7 +50,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(output === expected) } - test("StringIndexerUnessen") { + test("StringIndexerUnseen") { val data = sc.parallelize(Seq((0, "a"), (1, "b")), 2) val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From 1e53f9b3f6eeaef4fef5d4702dcc2aa39db0dff5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 5 Aug 2015 17:28:44 -0700 Subject: [PATCH 07/12] update the param (codegen side) --- .../spark/ml/param/shared/SharedParamsCodeGen.scala | 3 ++- .../apache/spark/ml/param/shared/sharedParams.scala | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e606d61add5e6..0beec4551b98e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,7 +53,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Boolean]("skipInvalid", "whether to skip invalid entries"), + ParamDesc[String]("handleInvalid", "how to handle invalid entries", + isValid = "ParamValidators.inArray(List(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 5d57be1f391c2..2dee845870a32 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -233,18 +233,18 @@ private[ml] trait HasFitIntercept extends Params { } /** - * Trait for shared param skipInvalid. + * Trait for shared param handleInvalid. */ -private[ml] trait HasSkipInvalid extends Params { +private[ml] trait HasHandleInvalid extends Params { /** - * Param for whether to skip invalid entries. + * Param for how to handle invalid entries. * @group param */ - final val skipInvalid: BooleanParam = new BooleanParam(this, "skipInvalid", "whether to skip invalid entries") + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(List("skip", "error"))) /** @group getParam */ - final def getSkipInvalid: Boolean = $(skipInvalid) + final def getHandleInvalid: String = $(handleInvalid) } /** From 414e249de422d718d33bb9e1b04525704f3e8a30 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 5 Aug 2015 17:35:31 -0700 Subject: [PATCH 08/12] And switch to using handleInvalid instead of skipInvalid --- .../spark/ml/feature/StringIndexer.scala | 23 ++++++++++--------- .../ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../spark/ml/param/shared/sharedParams.scala | 2 +- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ca33d61c625c4..cf543791295f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.collection.OpenHashMap * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasSkipInvalid { + with HasHandleInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -66,8 +66,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) /** @group setParam */ - def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value) - setDefault(skipInvalid, false) + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -115,8 +115,8 @@ class StringIndexerModel private[ml] ( } /** @group setParam */ - def setSkipInvalid(value: Boolean): this.type = set(skipInvalid, value) - setDefault(skipInvalid, false) + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -143,13 +143,14 @@ class StringIndexerModel private[ml] ( val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. - val filteredDataset = if (getSkipInvalid) { - val filterer = udf { label: String => - labelToIndex.contains(label) + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) } - dataset.where(filterer(dataset($(inputCol)))) - } else { - dataset + case _ => dataset } filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 0beec4551b98e..88068456fa0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -54,7 +54,7 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries", - isValid = "ParamValidators.inArray(List(\"skip\", \"error\"))"), + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 2dee845870a32..c83e27395d4aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -241,7 +241,7 @@ private[ml] trait HasHandleInvalid extends Params { * Param for how to handle invalid entries. * @group param */ - final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(List("skip", "error"))) + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(Array("skip", "error"))) /** @group getParam */ final def getHandleInvalid: String = $(handleInvalid) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index a443c4b3851eb..210fb0da8dc66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -66,7 +66,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val indexerSkipInvalid = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") - .setSkipInvalid(true) + .setHandleInvalid("skip") .fit(df) // Verify that we skip the c record val transformed = indexerSkipInvalid.transform(df2) From 7f37f6e0cd5cc377f2d9efae0ba1e3906fec58ea Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 5 Aug 2015 17:46:44 -0700 Subject: [PATCH 09/12] remove extra space (scala style) --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cf543791295f6..eb8def3ccef69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -150,7 +150,7 @@ class StringIndexerModel private[ml] ( } dataset.where(filterer(dataset($(inputCol)))) } - case _ => dataset + case _ => dataset } filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) From 81dd3126b94dca8d819deb4b21f57724e125b2c4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Aug 2015 10:42:29 -0700 Subject: [PATCH 10/12] Update the docs for handleInvalid param to be more descriptive --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 3 ++- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 88068456fa0e1..65cf2598d41c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,7 +53,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[String]("handleInvalid", "how to handle invalid entries", + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index c83e27395d4aa..26556547f675d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -238,10 +238,10 @@ private[ml] trait HasFitIntercept extends Params { private[ml] trait HasHandleInvalid extends Params { /** - * Param for how to handle invalid entries. + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. * @group param */ - final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries", ParamValidators.inArray(Array("skip", "error"))) + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) /** @group getParam */ final def getHandleInvalid: String = $(handleInvalid) From 045bf22ab1852bd2f3d2403a2202a558f9eda911 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Aug 2015 10:44:29 -0700 Subject: [PATCH 11/12] Add a second b entry so b gets 0 for sure --- .../scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 210fb0da8dc66..a46ba3a3dbd70 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -51,7 +51,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { } test("StringIndexerUnseen") { - val data = sc.parallelize(Seq((0, "a"), (1, "b")), 2) + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") From 38a4de9ea4a04e3765b81b54479e80157630a13d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 6 Aug 2015 10:52:52 -0700 Subject: [PATCH 12/12] fix long line --- .../org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 65cf2598d41c6..41c38e9436586 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -54,7 +54,8 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + - "will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")),