[SPARK-17219][ML] enhanced NaN value handling in Bucketizer #15428
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
```
```diff
@@ -73,15 +74,52 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /**
+   * Param for how to handle invalid entries. Options are skip (which will filter out rows with
+   * invalid values), error (which will throw an error), or keep (which will keep the invalid
+   * values in a certain way). Default behaviour is to report an error for invalid entries.
+   *
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (which will filter out rows with invalid values), " +
+    "error (which will throw an error), or keep (which will keep the invalid values " +
+    "in a certain way). Default behaviour is to report an error for invalid entries.",
+    ParamValidators.inArray(Array("skip", "error", "keep")))
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def gethandleInvalid: Option[Boolean] = $(handleInvalid) match {
+    case "keep" => Some(true)
+    case "skip" => Some(false)
+    case _ => None
+  }
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, "error")
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double =>
-      Bucketizer.binarySearchForBuckets($(splits), feature)
+    val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get
+
+    val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
+      Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
-    val newCol = bucketizer(dataset($(inputCol)))
-    val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol, newField.metadata)
+    val filteredDataset = {
+      if (!keepInvalid) {
+        // "skip" NaN option is set, will filter out NaN values in the dataset
+        dataset.na.drop.toDF()
+      } else {
+        dataset.toDF()
+      }
+    }
+    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newField = prepOutputField(filteredDataset.schema)
+    filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   private def prepOutputField(schema: StructType): StructField = {
```
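For context, here is a minimal usage sketch of the option added above. The setter name `sethandleInvalid` and the option strings come from this patch; the SparkSession setup, column names, data, and splits are made up for illustration:

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("BucketizerNaN").getOrCreate()
import spark.implicits._

// Toy input with one NaN in the feature column.
val df = Seq(-0.5, 0.2, 2.0, Double.NaN).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("buckets")
  .setSplits(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity))
  .sethandleInvalid("keep") // "keep": NaN gets an extra bucket; "skip": NaN rows are dropped

// With three regular buckets (0.0, 1.0, 2.0), the NaN row lands in the extra bucket 3.0.
bucketizer.transform(df).show()
```

Note that, as the diff reads at this commit, the default "error" setting also takes the `na.drop` path (keepInvalid is false for both "error" and "skip"), so inside transform() it behaves like "skip" rather than raising.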
```diff
@@ -126,10 +164,21 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
   /**
    * Binary searching in several buckets to place each data point.
    * @param splits array of split points
    * @param feature data point
+   * @param keepInvalid NaN flag.
+   *                    Set "true" to make an extra bucket for NaN values;
+   *                    Set "false" to report an error for NaN values
    * @return bucket for each data point
    * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    if (feature.isNaN) {
+  private[feature] def binarySearchForBuckets(
+      splits: Array[Double],
+      feature: Double,
+      keepInvalid: Boolean): Double = {
+    if (feature.isNaN && keepInvalid) {
+      // NaN data point found plus "keep" NaN option is set
       splits.length - 1
     } else if (feature == splits.last) {
       splits.length - 2
```
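For intuition, a self-contained sketch of the lookup logic: `findBucket` is a hypothetical stand-in, not the Spark implementation, and the range check here stands in for the SparkException mentioned in the scaladoc:

```scala
import java.util.Arrays

// Maps a value to a bucket index given sorted, distinct split points.
// NaN goes to an extra trailing bucket when keepInvalid is set, mirroring
// the patched binarySearchForBuckets above.
def findBucket(splits: Array[Double], feature: Double, keepInvalid: Boolean): Double = {
  if (feature.isNaN) {
    if (keepInvalid) splits.length - 1 // extra bucket reserved for NaN
    else throw new IllegalArgumentException("NaN feature with keepInvalid = false")
  } else if (feature == splits.last) {
    splits.length - 2 // the top boundary belongs to the last regular bucket
  } else {
    val idx = Arrays.binarySearch(splits, feature)
    if (idx >= 0) {
      idx // an exact split point starts its own bucket
    } else {
      val insertion = -idx - 1 // binarySearch encodes a miss as -(insertionPoint) - 1
      require(insertion > 0 && insertion < splits.length,
        s"Feature $feature out of range [${splits.head}, ${splits.last}]")
      insertion - 1
    }
  }
}

// With splits (-inf, 0.0, 1.0, +inf): 0.5 -> bucket 1.0; NaN with keepInvalid -> bucket 3.0.
```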
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

```diff
@@ -66,11 +66,13 @@ private[feature] trait QuantileDiscretizerBase extends Params
 /**
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
- * possible that the number of buckets used will be less than this value, for example, if there
- * are too few distinct values of the input to create enough distinct quantiles. Note also that
- * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
- * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
- * bucket(4).
+ * possible that the number of buckets used will be less than this value, for example, if there are
+ * too few distinct values of the input to create enough distinct quantiles. Note also that
+ * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user
+ * can also choose to either keep or remove NaN values within the dataset by setting handleInvalid.
+ * If the user chooses to keep NaN values, they will be handled specially and placed into their own
+ * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
+ * but NaNs will be counted in a special bucket[4].
  * The bin ranges are chosen using an approximate algorithm (see the documentation for
  * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
  * for a detailed description). The precision of the approximation can be controlled with the
```
```diff
@@ -100,6 +102,33 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /**
+   * Param for how to handle invalid entries. Options are skip (which will filter out rows with
+   * invalid values), error (which will throw an error), or keep (which will keep the invalid
+   * values in a certain way). Default behaviour is to report an error for invalid entries.
+   *
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (which will filter out rows with invalid values), " +
+    "error (which will throw an error), or keep (which will keep the invalid values " +
+    "in a certain way). Default behaviour is to report an error for invalid entries.",
+    ParamValidators.inArray(Array("skip", "error", "keep")))
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def gethandleInvalid: Option[Boolean] = $(handleInvalid) match {
+    case "keep" => Some(true)
+    case "skip" => Some(false)
+    case _ => None
+  }
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, "error")
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkNumericType(schema, $(inputCol))
```
```diff
@@ -124,7 +153,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
         s" buckets as a result.")
     }
-    val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
+    val bucketizer = new Bucketizer(uid)
+      .setSplits(distinctSplits.sorted)
+      .sethandleInvalid($(handleInvalid))
     copyValues(bucketizer.setParent(this))
   }
 
```
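End to end, fit() now forwards handleInvalid to the Bucketizer it returns. A hedged sketch of the intended behaviour described in the updated scaladoc, reusing the session and implicits from the earlier example (whether this commit's fit() already tolerates NaN while computing the quantiles is not shown in the diff):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// Toy input with one NaN value.
val hours = Seq(1.0, 2.0, 3.0, 4.0, Double.NaN).toDF("hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(4)
  .sethandleInvalid("keep") // forwarded to the underlying Bucketizer by fit()

// Non-NaN rows fall into buckets 0-3; the NaN row is counted in the special bucket 4.
discretizer.fit(hours).transform(hours).show()
```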
python/pyspark/ml/feature.py

```diff
@@ -1157,9 +1157,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
     categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter.
     It is possible that the number of buckets used will be less than this value, for example, if
     there are too few distinct values of the input to create enough distinct quantiles. Note also
-    that NaN values are handled specially and placed into their own bucket. For example, if 4
-    buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in
-    a special bucket(4).
+    that QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the
+    user can also choose to either keep or remove NaN values within the dataset by setting
+    handleInvalid. If the user chooses to keep NaN values, they will be handled specially and
+    placed into their own bucket, for example, if 4 buckets are used, then non-NaN data will be
+    put into buckets[0-3], but NaNs will be counted in a special bucket[4].
     The bin ranges are chosen using an approximate algorithm (see the documentation for
     :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description).
     The precision of the approximation can be controlled with the
```
Review discussion:

> (same as below) Is "possible that the number of buckets used will be less than this value" true? It was true before this used Dataset.approxQuantiles, but I don't think it is true any longer.

> In cases when the number of buckets requested by the user is greater than the number of distinct splits generated from Bucketizer, the returned number of buckets will be less than requested.

> Yep, you're right
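The behaviour agreed on above is easy to reproduce: with heavily skewed input, many requested quantiles coincide, only the distinct splits survive deduplication, and the log.warn in the diff fires. A toy sketch reusing the earlier session (variable names hypothetical):

```scala
// 98 identical values plus two outliers: most candidate quantiles are 0.0,
// so far fewer than the 10 requested buckets survive deduplication.
val skewed = (Seq.fill(98)(0.0) ++ Seq(1.0, 2.0)).toDF("value")

val disc = new QuantileDiscretizer()
  .setInputCol("value")
  .setOutputCol("bucket")
  .setNumBuckets(10)

// fit() logs "Some quantiles were identical. Bucketing to N buckets as a result."
val model = disc.fit(skewed)
println(model.getSplits.length - 1) // buckets actually used, less than the 10 requested
```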