Skip to content

Commit d74fc6b

Browse files
zhengruifeng authored and
srowen committed
[SPARK-29118][ML] Avoid redundant computation in transform of GMM & GLR
### What changes were proposed in this pull request? 1. GMM: obtain the prediction (double) from its probability prediction (vector). 2. GLR: obtain the prediction (double) from its link prediction (double). ### Why are the changes needed? It avoids computing the prediction twice. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests. Closes #25815 from zhengruifeng/gmm_transform_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
1 parent 376e17c commit d74fc6b

File tree

2 files changed

+38
-32
lines changed

2 files changed

+38
-32
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
3333
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
3434
Vector => OldVector, Vectors => OldVectors}
3535
import org.apache.spark.rdd.RDD
36-
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
37-
import org.apache.spark.sql.functions.udf
36+
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
37+
import org.apache.spark.sql.functions.{col, udf}
3838
import org.apache.spark.sql.types.{IntegerType, StructType}
3939
import org.apache.spark.storage.StorageLevel
4040

@@ -111,28 +111,32 @@ class GaussianMixtureModel private[ml] (
111111
override def transform(dataset: Dataset[_]): DataFrame = {
112112
transformSchema(dataset.schema, logging = true)
113113

114-
var predictionColNames = Seq.empty[String]
115-
var predictionColumns = Seq.empty[Column]
116-
117-
if ($(predictionCol).nonEmpty) {
118-
val predUDF = udf((vector: Vector) => predict(vector))
119-
predictionColNames :+= $(predictionCol)
120-
predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
121-
}
114+
val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
115+
var outputData = dataset
116+
var numColsOutput = 0
122117

123118
if ($(probabilityCol).nonEmpty) {
124119
val probUDF = udf((vector: Vector) => predictProbability(vector))
125-
predictionColNames :+= $(probabilityCol)
126-
predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
120+
outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol))
121+
numColsOutput += 1
122+
}
123+
124+
if ($(predictionCol).nonEmpty) {
125+
if ($(probabilityCol).nonEmpty) {
126+
val predUDF = udf((vector: Vector) => vector.argmax)
127+
outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))))
128+
} else {
129+
val predUDF = udf((vector: Vector) => predict(vector))
130+
outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol))
131+
}
132+
numColsOutput += 1
127133
}
128134

129-
if (predictionColNames.nonEmpty) {
130-
dataset.withColumns(predictionColNames, predictionColumns)
131-
} else {
135+
if (numColsOutput == 0) {
132136
this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" +
133137
" because no output columns were set.")
134-
dataset.toDF()
135138
}
139+
outputData.toDF
136140
}
137141

138142
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,31 +1036,33 @@ class GeneralizedLinearRegressionModel private[ml] (
10361036
}
10371037

10381038
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
1039-
var predictionColNames = Seq.empty[String]
1040-
var predictionColumns = Seq.empty[Column]
1041-
10421039
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
1040+
var outputData = dataset
1041+
var numColsOutput = 0
10431042

1044-
if ($(predictionCol).nonEmpty) {
1045-
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
1046-
predictionColNames :+= $(predictionCol)
1047-
predictionColumns :+= predictUDF(col($(featuresCol)), offset)
1043+
if (hasLinkPredictionCol) {
1044+
val predLinkUDF = udf((features: Vector, offset: Double) => predictLink(features, offset))
1045+
outputData = outputData
1046+
.withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset))
1047+
numColsOutput += 1
10481048
}
10491049

1050-
if (hasLinkPredictionCol) {
1051-
val predictLinkUDF =
1052-
udf { (features: Vector, offset: Double) => predictLink(features, offset) }
1053-
predictionColNames :+= $(linkPredictionCol)
1054-
predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset)
1050+
if ($(predictionCol).nonEmpty) {
1051+
if (hasLinkPredictionCol) {
1052+
val predUDF = udf((eta: Double) => familyAndLink.fitted(eta))
1053+
outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))))
1054+
} else {
1055+
val predUDF = udf((features: Vector, offset: Double) => predict(features, offset))
1056+
outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset))
1057+
}
1058+
numColsOutput += 1
10551059
}
10561060

1057-
if (predictionColNames.nonEmpty) {
1058-
dataset.withColumns(predictionColNames, predictionColumns)
1059-
} else {
1061+
if (numColsOutput == 0) {
10601062
this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" +
10611063
" because no output columns were set.")
1062-
dataset.toDF()
10631064
}
1065+
outputData.toDF
10641066
}
10651067

10661068
/**

0 commit comments

Comments
 (0)