From e3a5ff4313936283e4057bfad597f75a979a1583 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 18 Sep 2019 18:26:48 +0800 Subject: [PATCH 1/3] create pr --- .../apache/spark/ml/feature/Binarizer.scala | 42 ++++++++++++------- .../spark/ml/feature/BinarizerSuite.scala | 14 +++++++ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2b0862c60fdf..46eb716bcc45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -75,28 +75,38 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) val schema = dataset.schema val inputType = schema($(inputCol)).dataType val td = $(threshold) - - val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } - val binarizerVector = udf { (data: Vector) => - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - - data.foreachActive { (index, value) => - if (value > td) { - indices += index - values += 1.0 - } - } - - Vectors.sparse(data.size, indices.result(), values.result()).compressed - } - val metadata = outputSchema($(outputCol)).metadata inputType match { case DoubleType => + val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) + case _: VectorUDT => + val func = (vector: Vector) => { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + + vector.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } + } + + Vectors.sparse(vector.size, indices.result(), values.result()).compressed + } + + val binarizerVector = if (td < 0) { + udf { vector: Vector => + require(vector.isInstanceOf[DenseVector], + s"Threshold must be non-negative for operations on sparse 
vector $vector.") + func(vector) + } + } else { + udf(func) + } + dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 05d4a6ee2dab..e2ff4778deec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} @@ -101,6 +102,19 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("Binarizer should not support sparse vector with negative threshold") { + val threshold = -0.5 + val data = Seq((0, Vectors.sparse(3, Array(1), Array(0.5))), + (1, Vectors.dense(Array(0.0, 0.5, 0.0)))) + val df = data.toDF("id", "feature") + val binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(-0.5) + intercept[SparkException] { + binarizer.transform(df).count() + } + } test("read/write") { val t = new Binarizer() From cb52a09bf1887b54a3d271f861e0191c75e82e70 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 18 Sep 2019 18:27:30 +0800 Subject: [PATCH 2/3] nit --- .../test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index e2ff4778deec..81091fb3bb77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -103,7 +103,6 @@ class BinarizerSuite extends MLTest with 
DefaultReadWriteTest { } test("Binarizer should not support sparse vector with negative threshold") { - val threshold = -0.5 val data = Seq((0, Vectors.sparse(3, Array(1), Array(0.5))), (1, Vectors.dense(Array(0.0, 0.5, 0.0)))) val df = data.toDF("id", "feature") From 190c3b891387e2a137964c0cc3b5670dede23a25 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 19 Sep 2019 11:18:16 +0800 Subject: [PATCH 3/3] switch to ML fashion --- .../apache/spark/ml/feature/Binarizer.scala | 32 +++++++++---------- .../spark/ml/feature/BinarizerSuite.scala | 15 +++++---- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 46eb716bcc45..c4daf64dfc5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -77,38 +77,38 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) val td = $(threshold) val metadata = outputSchema($(outputCol)).metadata - inputType match { + val binarizerUDF = inputType match { case DoubleType => - val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } - dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) + udf { in: Double => if (in > td) 1.0 else 0.0 } - case _: VectorUDT => - val func = (vector: Vector) => { + case _: VectorUDT if td >= 0 => + udf { vector: Vector => val indices = ArrayBuilder.make[Int] val values = ArrayBuilder.make[Double] - vector.foreachActive { (index, value) => if (value > td) { indices += index values += 1.0 } } - Vectors.sparse(vector.size, indices.result(), values.result()).compressed } - val binarizerVector = if (td < 0) { - udf { vector: Vector => - require(vector.isInstanceOf[DenseVector], - s"Threshold must be non-negative for operations on sparse vector $vector.") - func(vector) + case _: VectorUDT if td 
< 0 => + this.logWarning(s"Binarization operations on sparse dataset with negative threshold " + + s"$td will build a dense output, so take care when applying to sparse input.") + udf { vector: Vector => + val values = Array.fill(vector.size)(1.0) + vector.foreachActive { (index, value) => + if (value <= td) { + values(index) = 0.0 + } } - } else { - udf(func) + Vectors.dense(values).compressed } - - dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) } + + dataset.withColumn($(outputCol), binarizerUDF(col($(inputCol))), metadata) } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 81091fb3bb77..91bec50fb904 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} @@ -102,16 +101,18 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } - test("Binarizer should not support sparse vector with negative threshold") { - val data = Seq((0, Vectors.sparse(3, Array(1), Array(0.5))), - (1, Vectors.dense(Array(0.0, 0.5, 0.0)))) - val df = data.toDF("id", "feature") + test("Binarizer should support sparse vector with negative threshold") { + val data = Seq( + (Vectors.sparse(3, Array(1), Array(0.5)), Vectors.dense(Array(1.0, 1.0, 1.0))), + (Vectors.dense(Array(0.0, 0.5, 0.0)), Vectors.dense(Array(1.0, 1.0, 1.0)))) + val df = data.toDF("feature", "expected") val binarizer = new Binarizer() .setInputCol("feature") .setOutputCol("binarized_feature") .setThreshold(-0.5) - intercept[SparkException] { - binarizer.transform(df).count() + 
binarizer.transform(df).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") } }