diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 56e2c543d100a..fb498d4e5225c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -171,25 +171,38 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - private[feature] def getInOutCols: (Array[String], Array[String]) = { - require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || - (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), - "QuantileDiscretizer only supports setting either inputCol/outputCol or" + - "inputCols/outputCols." - ) + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), + Seq(outputCols)) if (isSet(inputCol)) { - (Array($(inputCol)), Array($(outputCol))) - } else { - require($(inputCols).length == $(outputCols).length, - "inputCols number do not match outputCols") - ($(inputCols), $(outputCols)) + require(!isSet(numBucketsArray), + s"numBucketsArray can't be set for single-column QuantileDiscretizer.") } - } - @Since("1.6.0") - override def transformSchema(schema: StructType): StructType = { - val (inputColNames, outputColNames) = getInOutCols + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}).") + if (isSet(numBucketsArray)) { + require(getInputCols.length == getNumBucketsArray.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " + + s"should have equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).") + require(!isSet(numBuckets), + s"exactly one of numBuckets, numBucketsArray Params to be set, but both are set." ) + } + } + + val (inputColNames, outputColNames) = if (isSet(inputCols)) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } val existingFields = schema.fields var outputFields = existingFields inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b009038bbd833..66b6b4287cf0b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ @@ -414,33 +415,92 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } - test("Multiple Columns: Both inputCol and inputCols are set") { + test("Multiple Columns: Mismatched sizes of inputCols/outputCols") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() - .setInputCol("input") - .setOutputCol("result") + .setInputCols(Array("input")) + .setOutputCols(Array("result1", "result2")) .setNumBuckets(3) - .setInputCols(Array("input1", "input2")) val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) .map(Tuple1.apply).toDF("input") - // When both inputCol and inputCols are set, we throw Exception. intercept[IllegalArgumentException] { discretizer.fit(df) } } - test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + test("Multiple Columns: Mismatched sizes of inputCols/numBucketsArray") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() - .setInputCols(Array("input")) + .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) - .setNumBuckets(3) + .setNumBucketsArray(Array(2, 5, 10)) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Multiple Columns: Set both of numBuckets/numBucketsArray") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(2, 5)) + .setNumBuckets(2) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Setting numBucketsArray for Single-Column QuantileDiscretizer") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBucketsArray(Array(2, 5)) val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) .map(Tuple1.apply).toDF("input") intercept[IllegalArgumentException] { discretizer.fit(df) } } + + test("Assert exception is thrown if both multi-column and single-column params are set") { + val spark = this.spark + import spark.implicits._ + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("outputCols", Array("result1", "result2"))) + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("outputCol", "feature1")) + } + + test("Setting inputCol without setting outputCol") { + val spark = this.spark + import spark.implicits._ + + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + val numBuckets = 2 + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setNumBuckets(numBuckets) + val model = discretizer.fit(df) + val result = model.transform(df) + + val observedNumBuckets = result.select(discretizer.getOutputCol).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } }