Skip to content

Commit 3ffdd36

Browse files
committed
OneVsRest: use 'when ... otherwise' instead of a UDF to generate the new label in the binary reduction
1 parent 71a077f commit 3ffdd36

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] (
91 91
// add an accumulator column to store predictions of all the models
92 92
val accColName = "mbc$acc" + UUID.randomUUID().toString
93 93
val initUDF = udf { () => Map[Int, Double]() }
94 -
val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
95 94
val newDataset = dataset.withColumn(accColName, initUDF())
96 95

97 96
// persist if underlying dataset is not persistent.
@@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String)
195 194

196 195
// create k columns, one for each binary classifier.
197 196
val models = Range(0, numClasses).par.map { index =>
198 -
val labelUDF = udf { (label: Double) =>
199 -
if (label.toInt == index) 1.0 else 0.0
200 -
}
201 -
202 197
// generate new label metadata for the binary problem.
203 -
// TODO: use when ... otherwise after SPARK-7321 is merged
204 198
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
205 199
val labelColName = "mc2b$" + index
206 -
val trainingDataset =
207 -
multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
200 +
val trainingDataset = multiclassLabeled.withColumn(
201 +
labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
208 202
val classifier = getClassifier
209 203
val paramMap = new ParamMap()
210 204
paramMap.put(classifier.labelCol -> labelColName)

0 commit comments

Comments
 (0)