diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 3c7bcf7590e6..2a3837195df8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types._
 
 /**
@@ -90,7 +91,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
    * account of the embedded param map. So the param values should be determined solely by the input
    * param map.
    */
-  protected def createTransformFunc: IN => OUT
+  protected def transformFunc: UserDefinedFunction
   /**
    * Returns the data type of the output column.
   */
@@ -115,8 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    dataset.withColumn($(outputCol),
-      callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
+    dataset.withColumn($(outputCol), this.transformFunc(col($(inputCol))))
   }
 
   override def copy(extra: ParamMap): T = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index a6f878151de7..70bade556c85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -24,6 +24,8 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.BooleanParam
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -57,11 +59,13 @@ class DCT(override val uid: String)
 
   setDefault(inverse -> false)
 
-  override protected def createTransformFunc: Vector => Vector = { vec =>
-    val result = vec.toArray
-    val jTransformer = new DoubleDCT_1D(result.length)
-    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
-    Vectors.dense(result)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val result = input.toArray
+      val jTransformer = new DoubleDCT_1D(result.length)
+      if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
+      Vectors.dense(result)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 07a12df32035..63183dbc8d42 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -49,10 +51,12 @@ class ElementwiseProduct(override val uid: String)
   /** @group getParam */
   def getScalingVec: Vector = getOrDefault(scalingVec)
 
-  override protected def createTransformFunc: Vector => Vector = {
+  override protected def transformFunc: UserDefinedFunction = {
     require(params.contains(scalingVec), s"transformation requires a weight vector")
-    val elemScaler = new feature.ElementwiseProduct($(scalingVec))
-    elemScaler.transform
+    udf { input: Vector =>
+      val elemScaler = new feature.ElementwiseProduct($(scalingVec))
+      elemScaler.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
index f8bc7e3f0c03..30b79241eace 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -56,8 +58,10 @@ class NGram(override val uid: String)
 
   setDefault(n -> 2)
 
-  override protected def createTransformFunc: Seq[String] => Seq[String] = {
-    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Seq[String] =>
+      input.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index a603b3f83320..b74bb16f053d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -50,9 +52,11 @@ class Normalizer(override val uid: String)
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  override protected def createTransformFunc: Vector => Vector = {
-    val normalizer = new feature.Normalizer($(p))
-    normalizer.transform
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val normalizer = new feature.Normalizer($(p))
+      normalizer.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 42b26c8ee836..49549c79e4c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -24,7 +24,9 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.UserDefinedFunction
 
 /**
  * :: Experimental ::
@@ -56,8 +58,10 @@ class PolynomialExpansion(override val uid: String)
   /** @group setParam */
   def setDegree(value: Int): this.type = set(degree, value)
 
-  override protected def createTransformFunc: Vector => Vector = { v =>
-    PolynomialExpansion.expand(v, $(degree))
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      PolynomialExpansion.expand(input, $(degree))
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 8456a0e91580..1eb457115b5c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -35,8 +37,10 @@ class Tokenizer(override val uid: String)
 
   def this() = this(Identifiable.randomUID("tok"))
 
-  override protected def createTransformFunc: String => Seq[String] = {
-    _.toLowerCase.split("\\s")
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      input.toLowerCase.split("\\s")
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
@@ -124,12 +128,14 @@ class RegexTokenizer(override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
-    val re = $(pattern).r
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      val re = $(pattern).r
+      val str = if ($(toLowercase)) input.toLowerCase() else input
+      val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+      val minLength = $(minTokenLength)
+      tokens.filter(_.length >= minLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
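For reviewers, a minimal sketch of what a downstream UnaryTransformer subclass would look like under the new transformFunc contract (returning a UserDefinedFunction instead of a raw IN => OUT function). The UpperCaseTokenizer class, its UID prefix, and its splitting logic are hypothetical and not part of this patch; they only mirror the pattern used by the built-in transformers above.

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

// Hypothetical transformer: uppercases a string column and splits it on whitespace.
class UpperCaseTokenizer(override val uid: String)
  extends UnaryTransformer[String, Seq[String], UpperCaseTokenizer] {

  def this() = this(Identifiable.randomUID("upperTok"))

  // Subclasses now return a ready-to-use UserDefinedFunction; the base class
  // applies it as transformFunc(col($(inputCol))) in transform().
  override protected def transformFunc: UserDefinedFunction = {
    udf { input: String =>
      input.toUpperCase.split("\\s").toSeq
    }
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, true)
}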