From 106f8cfc082cc370653da1b297064e4246ad5d1c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 30 Jan 2018 11:55:34 -0800 Subject: [PATCH 1/6] [SPARK-23265][SQL]Update multi-column error handling logic in QuantileDiscretizer --- .../ml/feature/QuantileDiscretizer.scala | 33 +++++++++---------- .../ml/feature/QuantileDiscretizerSuite.scala | 29 ++++++++-------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 56e2c543d100a..8b3de14e25553 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -171,25 +171,24 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - private[feature] def getInOutCols: (Array[String], Array[String]) = { - require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || - (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), - "QuantileDiscretizer only supports setting either inputCol/outputCol or" + - "inputCols/outputCols." - ) - - if (isSet(inputCol)) { - (Array($(inputCol)), Array($(outputCol))) - } else { - require($(inputCols).length == $(outputCols).length, - "inputCols number do not match outputCols") - ($(inputCols), $(outputCols)) - } - } - @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val (inputColNames, outputColNames) = getInOutCols + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), + Seq(outputCols)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}).") + } + + val (inputColNames, outputColNames) = if (isSet(inputCols)) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } val existingFields = schema.fields var outputFields = existingFields inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b009038bbd833..55093917c5afb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.Pipeline import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { @@ -414,33 +417,29 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } - test("Multiple Columns: Both inputCol and inputCols are set") { + test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() - .setInputCol("input") - .setOutputCol("result") + .setInputCols(Array("input")) + .setOutputCols(Array("result1", "result2")) .setNumBuckets(3) - .setInputCols(Array("input1", "input2")) val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) .map(Tuple1.apply).toDF("input") - // When both inputCol and inputCols are set, we throw Exception. intercept[IllegalArgumentException] { discretizer.fit(df) } } - test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + test("Assert exception is thrown if both multi-column and single-column params are set") { val spark = this.spark import spark.implicits._ - val discretizer = new QuantileDiscretizer() - .setInputCols(Array("input")) - .setOutputCols(Array("result1", "result2")) - .setNumBuckets(3) - val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - .map(Tuple1.apply).toDF("input") - intercept[IllegalArgumentException] { - discretizer.fit(df) - } + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("outputCols", Array("result1", "result2"))) + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("outputCol", "feature1")) } } From d22093b64b306ce9c685c2f138d754caedc60382 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 30 Jan 2018 14:54:23 -0800 Subject: [PATCH 2/6] add check for numBucketsArray length --- .../spark/ml/feature/QuantileDiscretizer.scala | 7 +++++++ .../ml/feature/QuantileDiscretizerSuite.scala | 17 ++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 8b3de14e25553..9a4e1a79c1402 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -182,6 +182,13 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui s"for multi-column transform. Params (inputCols, outputCols) should have " + s"equal lengths, but they have different lengths: " + s"(${getInputCols.length}, ${getOutputCols.length}).") + if (isSet(numBucketsArray)) { + require(getInputCols.length == getNumBucketsArray.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " + + s"should have equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).") + } } val (inputColNames, outputColNames) = if (isSet(inputCols)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 55093917c5afb..41e7c2025e11b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -417,7 +417,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } - test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + test("Multiple Columns: Mismatched sizes of inputCols/outputCols") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() @@ -431,6 +431,21 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { } } + test("Multiple Columns: Mismatched sizes of inputCols/numBucketsArray") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(2, 5, 10)) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + test("Assert exception is thrown if both multi-column and single-column params are set") { val spark = this.spark import spark.implicits._ From 9674c3cb8d0150018224d2fc0fb8c70994e54c65 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 30 Jan 2018 21:48:41 -0800 Subject: [PATCH 3/6] address comments --- .../spark/ml/feature/QuantileDiscretizer.scala | 5 +++++ .../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 9a4e1a79c1402..9eeaf00880ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -176,6 +176,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols)) + if (isSet(inputCol)) { + require(!isSet(numBucketsArray), + s"numBucketsArray can't be set for single-column QuantileDiscretizer.") + } + if (isSet(inputCols)) { require(getInputCols.length == getOutputCols.length, s"QuantileDiscretizer $this has mismatched Params " + diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 41e7c2025e11b..7a12aacd6d950 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -446,6 +446,20 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { } } + test("Setting numBucketsArray for Single-Column QuantileDiscretizer") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBucketsArray(Array(2, 5)) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + test("Assert exception is thrown if both multi-column and single-column params are set") { val spark = this.spark import spark.implicits._ From f68940bea9fd6e16c18d91bcb28631f88534288e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 31 Jan 2018 10:42:28 -0800 Subject: [PATCH 4/6] address comments (2) --- .../spark/ml/feature/QuantileDiscretizer.scala | 2 ++ .../ml/feature/QuantileDiscretizerSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 9eeaf00880ff5..fb498d4e5225c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -193,6 +193,8 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " + s"should have equal lengths, but they have different lengths: " + s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).") + require(!isSet(numBuckets), + s"exactly one of numBuckets, numBucketsArray Params to be set, but both are set." ) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 7a12aacd6d950..2a7101fceb3eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -446,6 +446,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { } } + test("Multiple Columns: Set both of numBuckets/numBucketsArray") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(2, 5)) + .setNumBuckets(2) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + test("Setting numBucketsArray for Single-Column QuantileDiscretizer") { val spark = this.spark import spark.implicits._ From db35f610f890445b767392a4e23c5db73b7e642d Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 24 Apr 2018 22:20:43 -0700 Subject: [PATCH 5/6] resolve conflict --- .../apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 2a7101fceb3eb..d93a8a2fbe78a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,10 +18,8 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { From 73aeb1c50c25ea7c92f4b63d606f2efeec8871e1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 13 Jun 2018 16:03:34 -0700 Subject: [PATCH 6/6] adding a test to check setting inputCol only (without setting outputCol) works OK --- .../ml/feature/QuantileDiscretizerSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index d93a8a2fbe78a..66b6b4287cf0b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -485,4 +485,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { // this should fail because at least one of inputCol and inputCols must be set ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("outputCol", "feature1")) } + + test("Setting inputCol without setting outputCol") { + val spark = this.spark + import spark.implicits._ + + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + val numBuckets = 2 + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setNumBuckets(numBuckets) + val model = discretizer.fit(df) + val result = model.transform(df) + + val observedNumBuckets = result.select(discretizer.getOutputCol).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } }