8 changes: 5 additions & 3 deletions docs/ml-features.md
@@ -1104,9 +1104,11 @@ for more details on the API.
`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
categorical features. The number of bins is set by the `numBuckets` parameter. It is possible
that the number of buckets used will be less than this value, for example, if there are too few
distinct values of the input to create enough distinct quantiles. Note also that NaN values are
handled specially and placed into their own bucket. For example, if 4 buckets are used, then
non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer
Member: (same as below) Is "possible that the number of buckets used will be less than this value" true? It was true before this used Dataset.approxQuantiles, but I don't think it is true any longer.

Author: In cases when the number of buckets requested by the user is greater than the number of distinct splits generated from Bucketizer, the returned number of buckets will be less than requested.

Member: Yep, you're right

will raise an error when it finds NaN values in the dataset, but the user can also choose to either
keep or remove NaN values within the dataset by setting handleInvalid. If the user chooses to keep
NaN values, they will be handled specially and placed into their own bucket. For example, if 4 buckets
are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
The bin ranges are chosen using an approximate algorithm (see the documentation for
[approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a
detailed description). The precision of the approximation can be controlled with the
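For readers of the updated doc, a minimal usage sketch of the behaviour described above, assuming the sethandleInvalid setter introduced by this PR and a SparkSession named spark in scope (column names and data are illustrative, not taken from the docs):

import org.apache.spark.ml.feature.QuantileDiscretizer
import spark.implicits._

// A column with a few NaN values; with handleInvalid = "keep" the NaNs get their own bucket.
val df = Seq(-0.5, 0.0, 0.5, Double.NaN).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)
  .sethandleInvalid("keep")  // other options: "error" (default) and "skip"

// Non-NaN rows fall into buckets 0..2; NaN rows are counted in an extra bucket 3.
discretizer.fit(df).transform(df).show()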
63 changes: 56 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

@@ -73,15 +74,52 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with
* invalid values), or error (which will throw an error), or keep (which will keep the invalid
* values in a certain way). Default behaviour is to report an error for invalid entries.
Member: I'd just write: default: "error". Rewording as "report" instead of "throw" could confuse people.

*
* @group param
*/
@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"invalid entries. Options are skip (which will filter out rows with invalid values), or " +
"error (which will throw an error), or keep (which will keep the invalid values" +
" in a certain way). Default behaviour is to report an error for invalid entries.",
ParamValidators.inArray(Array("skip", "error", "keep")))

/** @group getParam */
@Since("2.1.0")
def gethandleInvalid: Option[Boolean] = $(handleInvalid) match {
Member: This should just return $(handleInvalid), just like any other Param getter method.

case "keep" => Some(true)
case "skip" => Some(false)
case _ => None
}
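For reference, the plain getter the reviewer asks for would presumably be just this (an illustrative sketch, not part of this commit):

/** @group getParam */
@Since("2.1.0")
def getHandleInvalid: String = $(handleInvalid)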

/** @group setParam */
@Since("2.1.0")
def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val bucketizer = udf { feature: Double =>
Bucketizer.binarySearchForBuckets($(splits), feature)
val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get

val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
}
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
dataset.withColumn($(outputCol), newCol, newField.metadata)
val filteredDataset = {
if (!keepInvalid) {
// "skip" NaN option is set, will filter out NaN values in the dataset
dataset.na.drop.toDF()
} else {
dataset.toDF()
}
}
val newCol = bucketizer(filteredDataset($(inputCol)))
val newField = prepOutputField(filteredDataset.schema)
filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
}

private def prepOutputField(schema: StructType): StructField = {
@@ -126,10 +164,21 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {

/**
* Binary searching in several buckets to place each data point.
* @param splits array of split points
* @param feature data point
* @param keepInvalid NaN flag.
* Set "true" to make an extra bucket for NaN values;
* Set "false" to report an error for NaN values
* @return bucket for each data point
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
if (feature.isNaN) {

private[feature] def binarySearchForBuckets(
splits: Array[Double],
feature: Double,
keepInvalid: Boolean): Double = {
if (feature.isNaN && keepInvalid) {
// NaN data point found plus "keep" NaN option is set
splits.length - 1
} else if (feature == splits.last) {
splits.length - 2
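To make the keepInvalid contract concrete, here is a small illustrative walk-through of the bucket indices this method would produce for a fixed set of splits (callable like this only from within org.apache.spark.ml.feature, as the test suite does; the values are made up):

// splits of length 4 define 3 regular buckets (0, 1, 2); NaN gets the extra bucket 3 when kept.
val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)

Bucketizer.binarySearchForBuckets(splits, -5.0, keepInvalid = true)        // 0.0
Bucketizer.binarySearchForBuckets(splits, 5.0, keepInvalid = true)         // 1.0
Bucketizer.binarySearchForBuckets(splits, 12.0, keepInvalid = true)        // 2.0
Bucketizer.binarySearchForBuckets(splits, Double.NaN, keepInvalid = true)  // 3.0, i.e. splits.length - 1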
@@ -66,11 +66,13 @@ private[feature] trait QuantileDiscretizerBase extends Params
/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The number of bins can be set using the `numBuckets` parameter. It is
* possible that the number of buckets used will be less than this value, for example, if there
* are too few distinct values of the input to create enough distinct quantiles. Note also that
* NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
* are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
* bucket(4).
* possible that the number of buckets used will be less than this value, for example, if there are
Member: Is this true? It was true before this used Dataset.approxQuantiles, but I don't think it is true any longer.

Author: same as the comment above

* too few distinct values of the input to create enough distinct quantiles. Note also that
* QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can
* also choose to either keep or remove NaN values within the dataset by setting handleInvalid.
* If the user chooses to keep NaN values, they will be handled specially and placed into their own
* bucket. For example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
* but NaNs will be counted in a special bucket[4].
* The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the
@@ -100,6 +102,33 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with
* invalid values), or error (which will throw an error), or keep (which will keep the invalid
* values in a certain way). Default behaviour is to report an error for invalid entries.
Member: I'd just write: default: "error". Rewording as "report" instead of "throw" could confuse people.

*
* @group param
*/
@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"invalid entries. Options are skip (which will filter out rows with invalid values), or " +
"error (which will throw an error), or keep (which will keep the invalid values" +
" in a certain way). Default behaviour is to report an error for invalid entries.",
ParamValidators.inArray(Array("skip", "error", "keep")))

/** @group getParam */
@Since("2.1.0")
def gethandleInvalid: Option[Boolean] = $(handleInvalid) match {
case "keep" => Some(true)
case "skip" => Some(false)
case _ => None
}

/** @group setParam */
@Since("2.1.0")
def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkNumericType(schema, $(inputCol))
@@ -124,7 +153,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
s" buckets as a result.")
}
val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
val bucketizer = new Bucketizer(uid)
.setSplits(distinctSplits.sorted)
.sethandleInvalid($(handleInvalid))
copyValues(bucketizer.setParent(this))
}

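The point of the review thread above, sketched: when the approximate quantiles collapse onto only a few distinct values, the fitted Bucketizer ends up with fewer buckets than requested (assuming a SparkSession named spark is in scope; data and column names are illustrative):

import spark.implicits._

// Three buckets requested, but only two distinct input values.
val df = Seq(1.0, 1.0, 1.0, 2.0).toDF("input")

val model = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)
  .fit(df)

// getSplits.length - 1 regular buckets; can be smaller than numBuckets when quantiles are identical
// (the "Some quantiles were identical" warning above is logged in that case).
println(model.getSplits.length - 1)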
@@ -98,6 +98,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)
.sethandleInvalid("keep")

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
@@ -111,8 +112,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
withClue("Invalid NaN split was not caught as an invalid split!") {
intercept[IllegalArgumentException] {
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)
}
}
@@ -138,7 +137,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val data = Array.fill(100)(Random.nextDouble())
val splits: Array[Double] = Double.NegativeInfinity +:
Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
val bsResult = Vectors.dense(data.map(x =>
Bucketizer.binarySearchForBuckets(splits, x, false)))
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
}
@@ -169,7 +169,7 @@ private object BucketizerSuite extends SparkFunSuite {
/** Check all values in splits, plus values between all splits. */
def checkBinarySearch(splits: Array[Double]): Unit = {
def testFeature(feature: Double, expectedBucket: Double): Unit = {
assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket,
s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
s" ${splits.mkString(", ")}")
}
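For context, the linearSearchForBuckets reference used in checkBinarySearch is not shown in this diff; a straightforward version of such a linear-scan reference might look like the following (an illustrative sketch, not the suite's exact code):

// Walk the splits left to right and return the first bucket whose upper bound is above the feature.
def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
  require(feature >= splits.head && feature <= splits.last,
    s"Feature value $feature out of range [${splits.head}, ${splits.last}]")
  if (feature == splits.last) {
    (splits.length - 2).toDouble
  } else {
    var bucket = 0
    while (feature >= splits(bucket + 1)) { bucket += 1 }
    bucket.toDouble
  }
}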
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.functions.udf

class QuantileDiscretizerSuite
@@ -76,20 +76,26 @@ class QuantileDiscretizerSuite
import spark.implicits._

val numBuckets = 3
val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN))
.map(Tuple1.apply).toDF("input")
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0)
val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)

val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)

// Reserve extra one bucket for NaN
val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1
val result = discretizer.fit(df).transform(df)
val observedNumBuckets = result.select("result").distinct.count
assert(observedNumBuckets == expectedNumBuckets,
s"Observed number of buckets are not correct." +
s" Expected $expectedNumBuckets but found $observedNumBuckets")
List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
case(u, v) =>
discretizer.sethandleInvalid(u)
val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected")
val result = discretizer.fit(dataFrame).transform(dataFrame)
result.select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}
}

test("Test transform method on unseen data") {
8 changes: 5 additions & 3 deletions python/pyspark/ml/feature.py
@@ -1157,9 +1157,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter.
It is possible that the number of buckets used will be less than this value, for example, if
Member: Here too: no longer the case

Author: same as the comment above.

there are too few distinct values of the input to create enough distinct quantiles. Note also
that NaN values are handled specially and placed into their own bucket. For example, if 4
buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in
a special bucket(4).
that QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user
Member: Actually no need to update Python API until it is updated to include handleNaN

Member: This isn't available in Python yet, so can you please revert this change to feature.py?

can also choose to either keep or remove NaN values within the dataset by setting
handleInvalid. If the user chooses to keep NaN values, they will be handled specially and placed
into their own bucket. For example, if 4 buckets are used, then non-NaN data will be put into
buckets[0-3], but NaNs will be counted in a special bucket[4].
The bin ranges are chosen using an approximate algorithm (see the documentation for
:py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description).
The precision of the approximation can be controlled with the
@@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(d1 - 2 * q1 * n) < error_double)
assert(math.abs(d2 - 2 * q2 * n) < error_double)
}
// test approxQuantile on NaN values
val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input")
val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0))
assert(resNaN.count(_.isNaN) == 0)
}

test("crosstab") {
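The added assertion above boils down to: approxQuantile should not surface NaN even when the input column contains NaN values. A minimal standalone check in the same spirit (assuming a SparkSession named spark is in scope; the probabilities and relative error are illustrative):

import spark.implicits._

val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
val quantiles = dfNaN.stat.approxQuantile("input", Array(0.25, 0.75), 0.001)
// No NaN should appear among the returned quantiles.
assert(quantiles.forall(!_.isNaN))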