Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,25 +171,38 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

private[feature] def getInOutCols: (Array[String], Array[String]) = {
require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
(!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
"QuantileDiscretizer only supports setting either inputCol/outputCol or" +
"inputCols/outputCols."
)
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting numBucketsArray when single-column can be an error. Since checkSingleVsMultiColumnParams doesn't support this usage, I think we may need to check it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment. I will add the check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkSingleVsMultiColumnParams can be used like ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), Seq(outputCols, splitsArray)).

If we want numBuckets and numBucketsArray to be exclusively set, you can use checkSingleVsMultiColumnParams like that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @viirya for your quick reply!
The reason I didn't use

    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, numBuckets),
      Seq(outputCols, numBucketsArray))

is that we can actually call setNumBuckets for multiple columns. I looked at the previous conversation, and we decided to allow setNumBuckets for multiple columns. In the multi-column case:

If however the numBucketsArray param is unset but the numBuckets param is set, 
the user is saying they want the same numBuckets across all columns, then we can 
use the multi-column version of approxQuantiles in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see. thanks!

Seq(outputCols))

if (isSet(inputCol)) {
(Array($(inputCol)), Array($(outputCol)))
} else {
require($(inputCols).length == $(outputCols).length,
"inputCols number do not match outputCols")
($(inputCols), $(outputCols))
require(!isSet(numBucketsArray),
s"numBucketsArray can't be set for single-column QuantileDiscretizer.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check if numBucketsArray and numBuckets are set at the same time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether I should add this check when I changed the code yesterday:
If both numBucketsArray and numBuckets are set, the current code will only take numBucketsArray. Also, numBuckets always has a default value even if it's not set. So yesterday I decided not to add the check.
But I guess it's better to tighten the code so that users cannot set numBuckets explicitly when numBucketsArray is set. I will make the change to add the check.

}
}

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
val (inputColNames, outputColNames) = getInOutCols
if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length,
s"QuantileDiscretizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols) should have " +
s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}).")
if (isSet(numBucketsArray)) {
require(getInputCols.length == getNumBucketsArray.length,
s"QuantileDiscretizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " +
s"should have equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).")
require(!isSet(numBuckets),
s"exactly one of numBuckets, numBucketsArray Params to be set, but both are set." )
}
}

val (inputColNames, outputColNames) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
}
val existingFields = schema.fields
var outputFields = existingFields
inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.feature

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql._

Expand Down Expand Up @@ -414,33 +415,92 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
}

test("Multiple Columns: Both inputCol and inputCols are set") {
test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
val spark = this.spark
import spark.implicits._
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
.setInputCols(Array("input"))
.setOutputCols(Array("result1", "result2"))
.setNumBuckets(3)
.setInputCols(Array("input1", "input2"))
val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
.map(Tuple1.apply).toDF("input")
// When both inputCol and inputCols are set, we throw Exception.
intercept[IllegalArgumentException] {
discretizer.fit(df)
}
}

test("Multiple Columns: Mismatched sizes of inputCols / outputCols") {
test("Multiple Columns: Mismatched sizes of inputCols/numBucketsArray") {
val spark = this.spark
import spark.implicits._
val discretizer = new QuantileDiscretizer()
.setInputCols(Array("input"))
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBuckets(3)
.setNumBucketsArray(Array(2, 5, 10))
val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
val df = data1.zip(data2).toSeq.toDF("input1", "input2")
intercept[IllegalArgumentException] {
discretizer.fit(df)
}
}

test("Multiple Columns: Set both of numBuckets/numBucketsArray") {
  val session = this.spark
  import session.implicits._

  // Two-column input frame; the values themselves are irrelevant to the
  // failure this test expects.
  val firstColumn = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
  val secondColumn = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
  val twoColumnDf = firstColumn.zip(secondColumn).toSeq.toDF("input1", "input2")

  // Multi-column estimator with BOTH numBucketsArray and numBuckets set
  // explicitly.
  val estimator = new QuantileDiscretizer()
    .setInputCols(Array("input1", "input2"))
    .setOutputCols(Array("result1", "result2"))
    .setNumBucketsArray(Array(2, 5))
    .setNumBuckets(2)

  // Fitting is expected to be rejected with an IllegalArgumentException.
  intercept[IllegalArgumentException] {
    estimator.fit(twoColumnDf)
  }
}

test("Setting numBucketsArray for Single-Column QuantileDiscretizer") {
  val session = this.spark
  import session.implicits._

  // One-column input frame named "input".
  val singleColumnDf = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    .map(Tuple1.apply)
    .toDF("input")

  // Single-column params (inputCol/outputCol) combined with the
  // multi-column-only numBucketsArray param.
  val estimator = new QuantileDiscretizer()
    .setInputCol("input")
    .setOutputCol("result")
    .setNumBucketsArray(Array(2, 5))

  // Fitting is expected to be rejected with an IllegalArgumentException.
  intercept[IllegalArgumentException] {
    estimator.fit(singleColumnDf)
  }
}

test("Assert exception is thrown if both multi-column and single-column params are set") {
  val session = this.spark
  import session.implicits._

  val featureDf = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")

  // Case 1: inputCol given together with inputCols.
  ParamsSuite.testExclusiveParams(
    new QuantileDiscretizer,
    featureDf,
    ("inputCol", "feature1"),
    ("inputCols", Array("feature1", "feature2")))

  // Case 2: single-column inputCol/outputCol given together with outputCols.
  ParamsSuite.testExclusiveParams(
    new QuantileDiscretizer,
    featureDf,
    ("inputCol", "feature1"),
    ("outputCol", "result1"),
    ("outputCols", Array("result1", "result2")))

  // Case 3: only outputCol is given; this should fail because at least one
  // of inputCol and inputCols must be set.
  ParamsSuite.testExclusiveParams(
    new QuantileDiscretizer,
    featureDf,
    ("outputCol", "feature1"))
}

test("Setting inputCol without setting outputCol") {
  val session = this.spark
  import session.implicits._

  val expectedBuckets = 2
  val singleColumnDf = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    .map(Tuple1.apply)
    .toDF("input")

  // outputCol is deliberately never set here; the column name is read back
  // through getOutputCol when inspecting the transformed frame.
  val estimator = new QuantileDiscretizer()
    .setInputCol("input")
    .setNumBuckets(expectedBuckets)

  val transformed = estimator.fit(singleColumnDf).transform(singleColumnDf)

  // The number of distinct values in the output column must match the
  // requested bucket count.
  val observedBuckets = transformed.select(estimator.getOutputCol).distinct().count()
  assert(observedBuckets === expectedBuckets,
    "Observed number of buckets does not equal expected number of buckets.")
}
}