From 7d93d5ab3d84d04c0636d9128215881f9d00a479 Mon Sep 17 00:00:00 2001
From: Wayne Zhang
Date: Tue, 2 May 2017 21:02:33 -0700
Subject: [PATCH 1/4] allow bucketizer to work for non-double column

---
 .../apache/spark/ml/feature/Bucketizer.scala  |  4 ++--
 .../spark/ml/feature/BucketizerSuite.scala    | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index d1f3b2af1e48..bb8f2a3aa5f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
 
-    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
     val newField = prepOutputField(filteredDataset.schema)
     filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index aac29137d791..5d725e7af42b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -162,6 +162,24 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setSplits(Array(0.1, 0.8, 0.9))
     testDefaultReadWrite(t)
   }
+
+  test("Bucket non-double features") {
+    val splits = Array(-3.0, 0.0, 3.0)
+    val validData: Array[Int] = Array(-2, -1, 0, 1, 2)
+    val expectedBuckets = Array(0.0, 0.0, 0.0, 1.0, 1.0)
+    val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
+    }
+  }
 }
 
 private object BucketizerSuite extends SparkFunSuite {
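Taken together, the two hunks in Bucketizer.scala mean the input column no longer has to be DoubleType: transformSchema now accepts any numeric type, and transform casts the column to Double before the bucketing UDF runs. A minimal sketch of the resulting behavior on an IntegerType column, assuming a local SparkSession (the object name, app name, and sample data are illustrative, not part of the patch):

    import org.apache.spark.ml.feature.Bucketizer
    import org.apache.spark.sql.SparkSession

    object BucketizeIntColumnDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("bucketize-int").getOrCreate()
        import spark.implicits._

        // An IntegerType column: before this patch transformSchema rejected it
        // because it was not DoubleType; now it passes the numeric check and
        // is cast to Double internally.
        val df = Seq(-2, -1, 0, 1, 2).toDF("feature")

        val bucketizer = new Bucketizer()
          .setInputCol("feature")
          .setOutputCol("result")
          .setSplits(Array(-3.0, 0.0, 3.0))

        bucketizer.transform(df).show()
        spark.stop()
      }
    }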
From a86fbde996afb1aed49c10a5785d934b4c12b2a2 Mon Sep 17 00:00:00 2001
From: Wayne Zhang
Date: Tue, 2 May 2017 23:01:21 -0700
Subject: [PATCH 2/4] update test for non-Double types

---
 .../spark/ml/feature/BucketizerSuite.scala | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 5d725e7af42b..1e0a868774aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -26,6 +26,9 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
 
 class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -165,19 +168,23 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
 
   test("Bucket non-double features") {
     val splits = Array(-3.0, 0.0, 3.0)
-    val validData: Array[Int] = Array(-2, -1, 0, 1, 2)
-    val expectedBuckets = Array(0.0, 0.0, 0.0, 1.0, 1.0)
-    val dataFrame: DataFrame = validData.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+    val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
+    val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
       .setOutputCol("result")
      .setSplits(splits)
 
-    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
-      case Row(x: Double, y: Double) =>
-        assert(x === y,
-          s"The feature value is not correct after bucketing. Expected $y but found $x")
+    val types = Seq(ShortType, IntegerType, LongType, FloatType)
+    for (mType <- types) {
+      val df = dataFrame.withColumn("feature", col("feature").cast(mType))
+      bucketizer.transform(df).select("result", "expected").collect().foreach {
+        case Row(x: Double, y: Double) =>
+          assert(x === y, "The feature value is not correct after bucketing in type " +
+            mType.toString + ". " + s"Expected $y but found $x.")
+      }
     }
   }
 }
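Besides exercising Short, Integer, Long, and Float inputs via cast, this revision corrects expectedBuckets from (0.0, 0.0, 0.0, 1.0, 1.0) to (0.0, 0.0, 1.0, 1.0, 1.0): each bucket is half-open on the right except the last, so with splits (-3.0, 0.0, 3.0) the buckets are [-3, 0) and [0, 3], and the value 0.0 belongs to bucket 1. A minimal sketch of that interval rule (illustrative only; Spark's Bucketizer.binarySearchForBuckets implements the same semantics with a binary search):

    // Bucket i covers [splits(i), splits(i + 1)); the last bucket also
    // includes the right endpoint splits.last.
    def bucketFor(splits: Array[Double], v: Double): Int = {
      require(v >= splits.head && v <= splits.last, s"$v is outside the splits range")
      if (v == splits.last) splits.length - 2
      else splits.lastIndexWhere(_ <= v)
    }

    assert(bucketFor(Array(-3.0, 0.0, 3.0), -2.0) == 0)
    assert(bucketFor(Array(-3.0, 0.0, 3.0), 0.0) == 1)  // 0.0 falls in [0, 3]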
" + s"Expected $y but found $x.") + } } } } From b4a5b6f5e40a5ad711c9a759227584a17e35fc24 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Wed, 3 May 2017 09:21:46 -0700 Subject: [PATCH 3/4] update test --- .../scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 1e0a868774aa..ab0358da4cfa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -166,7 +166,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } - test("Bucket non-double features") { + test("Bucket non-double numeric features") { val splits = Array(-3.0, 0.0, 3.0) val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0) From d149350d73fcc4f576880730661fd688fed8fc07 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 4 May 2017 09:41:13 -0700 Subject: [PATCH 4/4] include all numeric types --- .../org/apache/spark/ml/feature/BucketizerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ab0358da4cfa..420fb17ddce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ - class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import testImplicits._ @@ -166,7 +165,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } - test("Bucket non-double numeric features") { + test("Bucket numeric features") { val splits = Array(-3.0, 0.0, 3.0) val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0) @@ -177,12 +176,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - val types = Seq(ShortType, IntegerType, LongType, FloatType) + val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, + ByteType, DecimalType(10, 0)) for (mType <- types) { val df = dataFrame.withColumn("feature", col("feature").cast(mType)) bucketizer.transform(df).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => - assert(x === y, "The feature value is not correct after bucketing in type " + + assert(x === y, "The result is not correct after bucketing in type " + mType.toString + ". " + s"Expected $y but found $x.") } }