-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-22397][ML]add multiple columns support to QuantileDiscretizer #19715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
07bd868
87ee0f3
5038e21
97ad483
445bd84
0e5971b
a030da1
99726a1
486b68d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml._ | ||
| import org.apache.spark.ml.attribute.NominalAttribute | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.Dataset | ||
| import org.apache.spark.sql.types.StructType | ||
|
|
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType | |
| * Params for [[QuantileDiscretizer]]. | ||
| */ | ||
| private[feature] trait QuantileDiscretizerBase extends Params | ||
| with HasHandleInvalid with HasInputCol with HasOutputCol { | ||
| with HasHandleInvalid with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols { | ||
|
|
||
| /** | ||
| * Number of buckets (quantiles, or categories) into which data points are grouped. Must | ||
|
|
@@ -50,10 +50,26 @@ private[feature] trait QuantileDiscretizerBase extends Params | |
| /** @group getParam */ | ||
| def getNumBuckets: Int = getOrDefault(numBuckets) | ||
|
|
||
| /** | ||
| * Array of number of buckets (quantiles, or categories) into which data points are grouped. | ||
| * | ||
| * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. | ||
| * | ||
| * @group param | ||
| */ | ||
| val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " + | ||
| "(quantiles, or categories) into which data points are grouped. This is for multiple " + | ||
| "columns input. If numBucketsArray is not set but numBuckets is set, it means user wants " + | ||
|
||
| "to use the same numBuckets across all columns.") | ||
|
||
|
|
||
| /** @group getParam */ | ||
| def getNumBucketsArray: Array[Int] = $(numBucketsArray) | ||
|
|
||
| /** | ||
| * Relative error (see documentation for | ||
| * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description) | ||
| * Must be in the range [0, 1]. | ||
| * Note that in multiple columns case, relative error is applied to all columns. | ||
| * default: 0.001 | ||
| * @group param | ||
| */ | ||
|
|
@@ -68,7 +84,9 @@ private[feature] trait QuantileDiscretizerBase extends Params | |
| /** | ||
| * Param for how to handle invalid entries. Options are 'skip' (filter out rows with | ||
| * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special | ||
| * additional bucket). | ||
| * additional bucket). Note that in the multiple columns case, the invalid handling is applied | ||
| * to all columns. That is, for 'error' it will throw an error if any invalids are found in | ||
| * any column, for 'skip' it will skip rows with any invalids in any column, etc. | ||
| * Default: "error" | ||
| * @group param | ||
| */ | ||
|
|
@@ -86,6 +104,10 @@ private[feature] trait QuantileDiscretizerBase extends Params | |
| * categorical features. The number of bins can be set using the `numBuckets` parameter. It is | ||
| * possible that the number of buckets used will be smaller than this value, for example, if there | ||
| * are too few distinct values of the input to create enough distinct quantiles. | ||
| * Since 2.3.0, | ||
|
||
| * `QuantileDiscretizer` can also map multiple columns at once. Whether it maps a single column or | ||
| * multiple columns depends on which of the parameters `inputCol` and `inputCols` is set. When both | ||
| * are set, a log warning will be printed and by default it chooses `inputCol`. | ||
| * | ||
| * NaN handling: | ||
| * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This | ||
|
|
@@ -129,34 +151,95 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui | |
| @Since("2.1.0") | ||
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
|
||
| private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = { | ||
| if (isSet(inputCols) && isSet(inputCol)) { | ||
| logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + | ||
|
||
| "`QuantileDiscretize` only map one column specified by `inputCol`") | ||
|
||
| false | ||
| } else if (isSet(inputCols)) { | ||
| true | ||
| } else { | ||
| false | ||
| } | ||
| } | ||
|
|
||
| private[feature] def getInOutCols: (Array[String], Array[String]) = { | ||
| if (!isQuantileDiscretizeMultipleColumns) { | ||
| (Array($(inputCol)), Array($(outputCol))) | ||
| } else { | ||
| require($(inputCols).length == $(outputCols).length, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should add a small test case for mismatched sizes of `inputCols` and `outputCols`. |
||
| "inputCols number do not match outputCols") | ||
| ($(inputCols), $(outputCols)) | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| SchemaUtils.checkNumericType(schema, $(inputCol)) | ||
| val inputFields = schema.fields | ||
| require(inputFields.forall(_.name != $(outputCol)), | ||
| s"Output column ${$(outputCol)} already exists.") | ||
| val attr = NominalAttribute.defaultAttr.withName($(outputCol)) | ||
| val outputFields = inputFields :+ attr.toStructField() | ||
| val (inputColNames, outputColNames) = getInOutCols | ||
| val existingFields = schema.fields | ||
| var outputFields = existingFields | ||
| inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => | ||
|
||
| SchemaUtils.checkNumericType(schema, inputColName) | ||
| require(existingFields.forall(_.name != outputColName), | ||
| s"Output column ${outputColName} already exists.") | ||
| val attr = NominalAttribute.defaultAttr.withName(outputColName) | ||
| outputFields :+= attr.toStructField() | ||
| } | ||
| StructType(outputFields) | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): Bucketizer = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| val splits = dataset.stat.approxQuantile($(inputCol), | ||
| (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) | ||
| val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at this now, the But the multi-buckets case can perhaps still be cleaned up. How about something like this: override def fit(dataset: Dataset[_]): Bucketizer = {
transformSchema(dataset.schema, logging = true)
val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
if (isQuantileDiscretizeMultipleColumns) {
val splitsArray = if (isSet(numBucketsArray)) {
val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
(0.0 to 1.0 by 1.0 / numOfBuckets).toArray
}
val probabilityArray = probArrayPerCol.flatten.sorted.distinct
val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols),
probabilityArray, $(relativeError))
splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
val probSet = probs.toSet
val idxSet = probabilityArray.zipWithIndex.collect {
case (p, idx) if probSet(p) =>
idx
}.toSet
splits.zipWithIndex.collect {
case (s, idx) if idxSet(idx) =>
s
}
}
} else {
dataset.stat.approxQuantile($(inputCols),
(0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
}
bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
} else {
val splits = dataset.stat.approxQuantile($(inputCol),
(0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
bucketizer.setSplits(getDistinctSplits(splits))
}
copyValues(bucketizer.setParent(this))
}Then we don't need |
||
| if (isQuantileDiscretizeMultipleColumns) { | ||
|
||
| var bucketArray = Array.empty[Int] | ||
|
||
| if (isSet(numBucketsArray)) { | ||
| bucketArray = $(numBucketsArray) | ||
| } | ||
| else { | ||
| bucketArray = Array($(numBuckets)) | ||
| } | ||
| val probabilityArray = bucketArray.toSeq.flatMap { numOfBucket => | ||
| (0.0 to 1.0 by 1.0 / numOfBucket) | ||
| } | ||
| val splitsArray = dataset.stat.approxQuantile($(inputCols), | ||
| probabilityArray.sorted.toArray.distinct, $(relativeError)) | ||
| val distinctSplitsArray = splitsArray.toSeq.map { splits => | ||
| getDistinctSplits(splits) | ||
| } | ||
| bucketizer.setSplitsArray(distinctSplitsArray.toArray) | ||
| copyValues(bucketizer.setParent(this)) | ||
| } | ||
| else { | ||
|
||
| val splits = dataset.stat.approxQuantile($(inputCol), | ||
| (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) | ||
| bucketizer.setSplits(getDistinctSplits(splits)) | ||
| copyValues(bucketizer.setParent(this)) | ||
| } | ||
| } | ||
|
|
||
| private def getDistinctSplits(splits: Array[Double]): Array[Double] = { | ||
| splits(0) = Double.NegativeInfinity | ||
| splits(splits.length - 1) = Double.PositiveInfinity | ||
|
|
||
| val distinctSplits = splits.distinct | ||
| if (splits.length != distinctSplits.length) { | ||
| log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + | ||
| s" buckets as a result.") | ||
| } | ||
| val bucketizer = new Bucketizer(uid) | ||
| .setSplits(distinctSplits.sorted) | ||
| .setHandleInvalid($(handleInvalid)) | ||
| copyValues(bucketizer.setParent(this)) | ||
| distinctSplits.sorted | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -146,4 +146,172 @@ class QuantileDiscretizerSuite | |
| val model = discretizer.fit(df) | ||
| assert(model.hasParent) | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add 2 tests:
|
||
| test("Multiple Columns: Test observed number of buckets and their sizes match expected values") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
|
|
||
| val datasetSize = 100000 | ||
| val numBuckets = 5 | ||
| val data1 = Array.range(1, 100001, 1).map(_.toDouble) | ||
| val data2 = Array.range(1, 200000, 2).map(_.toDouble) | ||
| val data = (0 until 100000).map { idx => | ||
| (data1(idx), data2(idx)) | ||
| } | ||
|
||
| val df: DataFrame = data.toSeq.toDF("input1", "input2") | ||
|
||
|
|
||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCols(Array("input1", "input2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setNumBuckets(numBuckets) | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns()) | ||
| val result = discretizer.fit(df).transform(df) | ||
|
|
||
| val relativeError = discretizer.getRelativeError | ||
| val isGoodBucket = udf { | ||
| (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) | ||
| } | ||
|
|
||
| for (i <- 1 to 2) { | ||
| val observedNumBuckets = result.select("result" + i).distinct.count | ||
| assert(observedNumBuckets === numBuckets, | ||
| "Observed number of buckets does not equal expected number of buckets.") | ||
|
|
||
| val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count | ||
| assert(numGoodBuckets === numBuckets, | ||
| "Bucket sizes are not within expected relative error tolerance.") | ||
| } | ||
| } | ||
|
|
||
| test("Multiple Columns: Test on data with high proportion of duplicated values") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
|
|
||
| val numBuckets = 5 | ||
| val expectedNumBucket = 3 | ||
| val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0) | ||
| val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 2.0) | ||
| val data = (0 until data1.length).map { idx => | ||
| (data1(idx), data2(idx)) | ||
| } | ||
|
||
| val df: DataFrame = data.toSeq.toDF("input1", "input2") | ||
|
||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCols(Array("input1", "input2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setNumBuckets(numBuckets) | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns()) | ||
| val result = discretizer.fit(df).transform(df) | ||
| for (i <- 1 to 2) { | ||
| val observedNumBuckets = result.select("result" + i).distinct.count | ||
| assert(observedNumBuckets == expectedNumBucket, | ||
| s"Observed number of buckets are not correct." + | ||
| s" Expected $expectedNumBucket but found ($observedNumBuckets") | ||
| } | ||
| } | ||
|
|
||
| test("Multiple Columns: Test transform on data with NaN value") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
|
|
||
| val numBuckets = 3 | ||
| val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) | ||
| val expectedKeep1 = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) | ||
| val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN) | ||
| val expectedKeep2 = Array(1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0) | ||
|
|
||
| val data = (0 until validData1.length).map { idx => | ||
| (validData1(idx), validData2(idx), expectedKeep1(idx), expectedKeep2(idx)) | ||
| } | ||
| val dataFrame: DataFrame = data.toSeq.toDF("input1", "input2", "expected1", "expected2") | ||
|
|
||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCols(Array("input1", "input2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setNumBuckets(numBuckets) | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns()) | ||
|
|
||
| withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { | ||
| intercept[SparkException] { | ||
| discretizer.fit(dataFrame).transform(dataFrame).collect() | ||
| } | ||
| } | ||
|
|
||
| discretizer.setHandleInvalid("keep") | ||
| discretizer.fit(dataFrame).transform(dataFrame). | ||
| select("result1", "expected1", "result2", "expected2") | ||
| .collect().foreach { | ||
| case Row(r1: Double, e1: Double, r2: Double, e2: Double) => | ||
| assert(r1 === e1, | ||
| s"The result value is not correct after bucketing. Expected $e1 but found $r1") | ||
| assert(r2 === e2, | ||
| s"The result value is not correct after bucketing. Expected $e2 but found $r2") | ||
| } | ||
|
|
||
| discretizer.setHandleInvalid("skip") | ||
| val result = discretizer.fit(dataFrame).transform(dataFrame) | ||
| for (i <- 1 to 2) { | ||
| val skipResults1: Array[Double] = result.select("result" + i).as[Double].collect() | ||
| assert(skipResults1.length === 7) | ||
| assert(skipResults1.forall(_ !== 4.0)) | ||
| } | ||
| } | ||
|
|
||
| test("Multiple Columns: Test numBucketsArray") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
|
|
||
| val datasetSize = 20 | ||
| val numBucketsArray: Array[Int] = Array(2, 5, 10) | ||
| val data1 = Array.range(1, 21, 1).map(_.toDouble) | ||
| val expected1 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, | ||
| 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) | ||
|
||
| val data2 = Array.range(1, 40, 2).map(_.toDouble) | ||
| val expected2 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, | ||
| 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) | ||
| val data3 = Array.range(1, 60, 3).map(_.toDouble) | ||
| val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, | ||
| 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) | ||
| val data = (0 until 20).map { idx => | ||
| (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) | ||
| } | ||
| val df: DataFrame = | ||
| data.toSeq.toDF("input1", "input2", "input3", "expected1", "expected2", "expected3") | ||
|
|
||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCols(Array("input1", "input2", "input3")) | ||
| .setOutputCols(Array("result1", "result2", "result3")) | ||
| .setNumBucketsArray(numBucketsArray) | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns()) | ||
| discretizer.fit(df).transform(df). | ||
| select("result1", "expected1", "result2", "expected2", "result3", "expected3") | ||
| .collect().foreach { | ||
| case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => | ||
| assert(r1 === e1, | ||
| s"The result value is not correct after bucketing. Expected $e1 but found $r1") | ||
| assert(r2 === e2, | ||
| s"The result value is not correct after bucketing. Expected $e2 but found $r2") | ||
| assert(r3 === e3, | ||
| s"The result value is not correct after bucketing. Expected $e3 but found $r3") | ||
| } | ||
| } | ||
|
|
||
| test("multiple columns: read/write") { | ||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCols(Array("input1", "input2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setNumBucketsArray(Array(5, 10)) | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns()) | ||
| testDefaultReadWrite(discretizer) | ||
| } | ||
|
|
||
| test("Both inputCol and inputCols are set") { | ||
| val discretizer = new QuantileDiscretizer() | ||
| .setInputCol("input") | ||
| .setOutputCol("result") | ||
| .setNumBuckets(3) | ||
| .setInputCols(Array("input1", "input2")) | ||
|
|
||
| // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. | ||
| assert(discretizer.isQuantileDiscretizeMultipleColumns() == false) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can add a comment about "each value must be greater than or equal to 2"