Cleanups: docs cleanups, slightly improved unit test coverage, fixed naming of set/get for handleInvalid
jkbradley committed Oct 26, 2016
commit 2644235f111bbbf43fd1f30d24d318735553e034
13 changes: 8 additions & 5 deletions docs/ml-features.md
@@ -1103,13 +1103,16 @@ for more details on the API.

`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
categorical features. The number of bins is set by the `numBuckets` parameter. It is possible
that the number of buckets used will be less than this value, for example, if there are too few
distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer
will raise an error when it finds NaN value in the dataset, but user can also choose to either
keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep
that the number of buckets used will be smaller than this value, for example, if there are too few
distinct values of the input to create enough distinct quantiles.

NaN values: Note also that QuantileDiscretizer
will raise an error when it finds NaN values in the dataset, but the user can also choose to either
keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep
NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets
are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
The bin ranges are chosen using an approximate algorithm (see the documentation for

Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
[approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a
detailed description). The precision of the approximation can be controlled with the
`relativeError` parameter. When set to zero, exact quantiles are calculated
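For readers skimming the doc change above, a minimal, hedged sketch of the documented behavior (illustrative only, not part of this diff; it assumes a `SparkSession` named `spark` and an input column named `hour`):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// Sketch only: `spark` is an existing SparkSession.
import spark.implicits._

val df = Seq(1.0, 2.0, 3.0, 4.0, Double.NaN).toDF("hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(4)          // may produce fewer buckets if there are too few distinct quantiles
  .setHandleInvalid("keep")  // NaN rows go to their own extra bucket

discretizer.fit(df).transform(df).show()
```

With `handleInvalid` set to `"keep"`, non-NaN values fall into buckets 0 through 3 and NaN rows are counted in the special bucket 4, matching the paragraph above.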
44 changes: 28 additions & 16 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -47,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
* also includes y. Splits should be of length >= 3 and strictly increasing.
* Values at -inf, inf must be explicitly provided to cover all Double values;
* otherwise, values outside the splits specified will be treated as errors.
*
* See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
*
* @group param
*/
@Since("1.4.0")
@@ -75,37 +78,36 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with
* invalid values), or error (which will throw an error), or keep (which will keep the invalid
* values in certain way).
* Param for how to handle invalid entries. Options are skip (filter out rows with
* invalid values), error (throw an error), or keep (keep invalid values in a special additional
* bucket).
* Default: "error"
* @group param
*/
@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
"invalid entries. Options are skip (which will filter out rows with invalid values), or" +
"error (which will throw an error), or keep (which will keep the invalid values" +
" in certain way). Default behaviour is to report an error for invalid entries.",
ParamValidators.inArray(Array("skip", "error", "keep")))
"invalid entries. Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
ParamValidators.inArray(Bucketizer.supportedHandleInvalid))

/** @group getParam */
@Since("2.1.0")
def gethandleInvalid: String = $(handleInvalid)
def getHandleInvalid: String = $(handleInvalid)

/** @group setParam */
@Since("2.1.0")
def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val (filteredDataset, keepInvalid) = {
if ("skip" == gethandleInvalid) {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// "skip" NaN option is set, will filter out NaN values in the dataset
(dataset.na.drop.toDF(), false)
(dataset.na.drop().toDF(), false)
} else {
(dataset.toDF(), "keep" == gethandleInvalid)
(dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
}
}

@@ -140,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {

private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalid: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

/**
* We require splits to be of length >= 3 and to be in strictly increasing order.
* No NaN split should be accepted.
@@ -173,9 +181,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
splits: Array[Double],
feature: Double,
keepInvalid: Boolean): Double = {
if (feature.isNaN && keepInvalid) {
// NaN data point found plus "keep" NaN option is set
splits.length - 1
if (feature.isNaN) {
if (keepInvalid) {
splits.length - 1
} else {
throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," +
" try setting Bucketizer.handleInvalid.")
}
} else if (feature == splits.last) {
splits.length - 2
} else {
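To make the `skip`/`error`/`keep` options described in the scaladoc above concrete, a hedged usage sketch (illustrative only, not part of the patch; assumes a `SparkSession` named `spark`):

```scala
import org.apache.spark.ml.feature.Bucketizer

// Sketch only: `spark` is an existing SparkSession.
import spark.implicits._

val df = Seq(-0.5, 0.3, Double.NaN).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketed")
  .setSplits(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))

// "keep": NaN is assigned the extra bucket index splits.length - 1 (2.0 here)
bucketizer.setHandleInvalid("keep").transform(df).show()

// "skip": rows containing NaN are filtered out before bucketing
bucketizer.setHandleInvalid("skip").transform(df).show()

// "error" (the default) throws a SparkException on the NaN row
```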
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
/**
* Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
*
* See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
*
* default: 2
* @group param
*/
@@ -61,19 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params

/** @group getParam */
def getRelativeError: Double = getOrDefault(relativeError)

/**
* Param for how to handle invalid entries. Options are skip (filter out rows with
* invalid values), error (throw an error), or keep (keep invalid values in a special additional
* bucket).
* Default: "error"
* @group param
*/
@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
"invalid entries. Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

/** @group getParam */
@Since("2.1.0")
def getHandleInvalid: String = $(handleInvalid)

}

/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The number of bins can be set using the `numBuckets` parameter. It is
* possible that the number of buckets used will be less than this value, for example, if there are
* too few distinct values of the input to create enough distinct quantiles. Note also that
* QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can
* also choose to either keep or remove NaN values within the dataset by setting handleInvalid.
* If user chooses to keep NaN values, they will be handled specially and placed into their own
* possible that the number of buckets used will be smaller than this value, for example, if there
* are too few distinct values of the input to create enough distinct quantiles.
*
* NaN handling: Note also that
* QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can
* also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`.
* If the user chooses to keep NaN values, they will be handled specially and placed into their own
* bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
* but NaNs will be counted in a special bucket[4].
* The bin ranges are chosen using an approximate algorithm (see the documentation for
*
* Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the
* `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`,
@@ -102,28 +127,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Param for how to handle invalid entries. Options are skip (which will filter out rows with
* invalid values), or error (which will throw an error), or keep (which will keep the invalid
* values in certain way). Default behaviour is to report an error for invalid entries.
* Default: "error"
* @group param
*/
@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
"invalid entries. Options are skip (which will filter out rows with invalid values), or" +
"error (which will throw an error), or keep (which will keep the invalid values" +
" in certain way). Default behaviour is to report an error for invalid entries.",
ParamValidators.inArray(Array("skip", "error", "keep")))

/** @group getParam */
@Since("2.1.0")
def gethandleInvalid: String = $(handleInvalid)

/** @group setParam */
@Since("2.1.0")
def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
@@ -151,7 +157,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
}
val bucketizer = new Bucketizer(uid)
.setSplits(distinctSplits.sorted)
.sethandleInvalid($(handleInvalid))
.setHandleInvalid($(handleInvalid))
copyValues(bucketizer.setParent(this))
}

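The fit path above forwards `handleInvalid` to the `Bucketizer` it returns (via `.setHandleInvalid($(handleInvalid))`); a brief sketch of that relationship (illustrative only; assumes a `SparkSession` named `spark`):

```scala
import org.apache.spark.ml.feature.{Bucketizer, QuantileDiscretizer}

// Sketch only: `spark` is an existing SparkSession.
import spark.implicits._

val df = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("input")

val model: Bucketizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setHandleInvalid("keep")
  .fit(df)

// The fitted Bucketizer inherits the discretizer's handleInvalid setting.
assert(model.getHandleInvalid == "keep")
println(model.getSplits.mkString(", "))  // learned quantile boundaries
```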
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -98,21 +98,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)
.sethandleInvalid("keep")

bucketizer.setHandleInvalid("keep")
bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}

bucketizer.setHandleInvalid("skip")
val skipResults: Array[Double] = bucketizer.transform(dataFrame)
.select("result").as[Double].collect()
assert(skipResults.length === 7)
assert(skipResults.forall(_ !== 4.0))

bucketizer.setHandleInvalid("error")
withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") {
intercept[SparkException] {
bucketizer.transform(dataFrame).collect()
}
}
}

test("Bucket continuous features, with NaN splits") {
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
withClue("Invalid NaN split was not caught as an invalid split!") {
withClue("Invalid NaN split was not caught during Bucketizer initialization") {
intercept[IllegalArgumentException] {
val bucketizer: Bucketizer = new Bucketizer()
.setSplits(splits)
new Bucketizer().setSplits(splits)
}
}
}
mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql._
@@ -85,9 +85,16 @@ class QuantileDiscretizerSuite
.setOutputCol("result")
.setNumBuckets(numBuckets)

withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
val dataFrame: DataFrame = validData.toSeq.toDF("input")
intercept[SparkException] {
discretizer.fit(dataFrame).transform(dataFrame).collect()
}
}

List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
case(u, v) =>
discretizer.sethandleInvalid(u)
discretizer.setHandleInvalid(u)
val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected")
val result = discretizer.fit(dataFrame).transform(dataFrame)
result.select("result", "expected").collect().foreach {
sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -151,9 +151,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
assert(math.abs(d2 - 2 * q2 * n) < error_double)
}
// test approxQuantile on NaN values
val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input")
val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0))
assert(resNaN.count(_.isNaN) == 0)
val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head)
assert(resNaN.count(_.isNaN) === 0)
}

test("crosstab") {
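A standalone sketch of the behavior this last test pins down (illustrative only; assumes a `SparkSession` named `spark`): `approxQuantile` ignores NaN inputs, so the returned quantiles contain no NaN.

```scala
// Sketch only: `spark` is an existing SparkSession.
import spark.implicits._

val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
val quantiles = dfNaN.stat.approxQuantile("input", Array(0.25, 0.75), 0.001)
assert(quantiles.forall(!_.isNaN))
```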