Commit b4c4329

Use udf to replace callUDF for ML
1 parent ad5b7cf commit b4c4329

7 files changed: +52 −26 lines

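In short, UnaryTransformer previously exposed a raw Scala closure (createTransformFunc) that transform() wrapped at the call site via callUDF(f, returnType, col). After this commit, each transformer instead supplies a ready-made UserDefinedFunction (transformFunc) built with udf, and transform() simply applies it. A minimal sketch of the two call styles, assuming imports of org.apache.spark.sql.functions._, org.apache.spark.sql.UserDefinedFunction, and org.apache.spark.sql.types.IntegerType; f and df are illustrative placeholders, not names from this commit:

    // Old style: pass a plain closure plus an explicit return type.
    val f: String => Int = _.length
    df.withColumn("len", callUDF(f, IntegerType, df("text")))

    // New style: build a UserDefinedFunction once; the return type
    // is inferred from the closure's type.
    val lenUdf: UserDefinedFunction = udf(f)
    df.withColumn("len", lenUdf(col("text")))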

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 3 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types._
 
 /**
@@ -90,7 +91,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
    * account of the embedded param map. So the param values should be determined solely by the input
    * param map.
    */
-  protected def createTransformFunc: IN => OUT
+  protected def transformFunc: UserDefinedFunction
 
   /**
    * Returns the data type of the output column.
@@ -115,8 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    dataset.withColumn($(outputCol),
-      callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
+    dataset.withColumn($(outputCol), this.transformFunc(col($(inputCol))))
   }
 
   override def copy(extra: ParamMap): T = defaultCopy(extra)
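For subclass authors, the change means overriding transformFunc instead of createTransformFunc. A minimal sketch against the API above; Reverser is a hypothetical transformer, not part of this commit:

    import org.apache.spark.ml.UnaryTransformer
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.UserDefinedFunction
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.{DataType, StringType}

    // Hypothetical transformer showing the new contract.
    class Reverser(override val uid: String)
      extends UnaryTransformer[String, String, Reverser] {

      def this() = this(Identifiable.randomUID("rev"))

      // Return a ready-made UserDefinedFunction rather than a raw closure.
      override protected def transformFunc: UserDefinedFunction = {
        udf { input: String => input.reverse }
      }

      override protected def outputDataType: DataType = StringType
    }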

mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala

Lines changed: 9 additions & 5 deletions
@@ -24,6 +24,8 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.BooleanParam
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -57,11 +59,13 @@ class DCT(override val uid: String)
 
   setDefault(inverse -> false)
 
-  override protected def createTransformFunc: Vector => Vector = { vec =>
-    val result = vec.toArray
-    val jTransformer = new DoubleDCT_1D(result.length)
-    if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
-    Vectors.dense(result)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val result = input.toArray
+      val jTransformer = new DoubleDCT_1D(result.length)
+      if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
+      Vectors.dense(result)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {

mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala

Lines changed: 7 additions & 3 deletions
@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -49,10 +51,12 @@ class ElementwiseProduct(override val uid: String)
   /** @group getParam */
   def getScalingVec: Vector = getOrDefault(scalingVec)
 
-  override protected def createTransformFunc: Vector => Vector = {
+  override protected def transformFunc: UserDefinedFunction = {
     require(params.contains(scalingVec), s"transformation requires a weight vector")
-    val elemScaler = new feature.ElementwiseProduct($(scalingVec))
-    elemScaler.transform
+    udf { input: Vector =>
+      val elemScaler = new feature.ElementwiseProduct($(scalingVec))
+      elemScaler.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()

mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala

Lines changed: 6 additions & 2 deletions
@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -56,8 +58,10 @@ class NGram(override val uid: String)
 
   setDefault(n -> 2)
 
-  override protected def createTransformFunc: Seq[String] => Seq[String] = {
-    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Seq[String] =>
+      input.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {

mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 7 additions & 3 deletions
@@ -23,6 +23,8 @@ import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -50,9 +52,11 @@ class Normalizer(override val uid: String)
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  override protected def createTransformFunc: Vector => Vector = {
-    val normalizer = new feature.Normalizer($(p))
-    normalizer.transform
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      val normalizer = new feature.Normalizer($(p))
+      normalizer.transform(input)
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 6 additions & 2 deletions
@@ -24,7 +24,9 @@ import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.UserDefinedFunction
 
 /**
  * :: Experimental ::
@@ -56,8 +58,10 @@ class PolynomialExpansion(override val uid: String)
   /** @group setParam */
   def setDegree(value: Int): this.type = set(degree, value)
 
-  override protected def createTransformFunc: Vector => Vector = { v =>
-    PolynomialExpansion.expand(v, $(degree))
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: Vector =>
+      PolynomialExpansion.expand(input, $(degree))
+    }
   }
 
   override protected def outputDataType: DataType = new VectorUDT()

mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

Lines changed: 14 additions & 8 deletions
@@ -21,6 +21,8 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.UserDefinedFunction
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -35,8 +37,10 @@ class Tokenizer(override val uid: String)
 
   def this() = this(Identifiable.randomUID("tok"))
 
-  override protected def createTransformFunc: String => Seq[String] = {
-    _.toLowerCase.split("\\s")
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      input.toLowerCase.split("\\s")
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
@@ -124,12 +128,14 @@ class RegexTokenizer(override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
-    val re = $(pattern).r
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+  override protected def transformFunc: UserDefinedFunction = {
+    udf { input: String =>
+      val re = $(pattern).r
+      val str = if ($(toLowercase)) input.toLowerCase() else input
+      val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+      val minLength = $(minTokenLength)
+      tokens.filter(_.length >= minLength)
+    }
  }
 
   override protected def validateInputType(inputType: DataType): Unit = {
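End to end, the reworked transformers behave as before. A quick usage sketch with Tokenizer, assuming a Spark 1.x-style sqlContext in scope; the data and column names are illustrative, not from this commit:

    import org.apache.spark.ml.feature.Tokenizer

    val df = sqlContext.createDataFrame(Seq(
      (0, "Hi I heard about Spark")
    )).toDF("id", "sentence")

    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    // transform() now applies the udf-backed transformFunc under the hood.
    tokenizer.transform(df).select("words").show()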
