@@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] (
9191 // add an accumulator column to store predictions of all the models
9292 val accColName = " mbc$acc" + UUID .randomUUID().toString
9393 val initUDF = udf { () => Map [Int , Double ]() }
94- val mapType = MapType (IntegerType , DoubleType , valueContainsNull = false )
9594 val newDataset = dataset.withColumn(accColName, initUDF())
9695
9796 // persist if underlying dataset is not persistent.
@@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String)
195194
196195 // create k columns, one for each binary classifier.
197196 val models = Range (0 , numClasses).par.map { index =>
198- val labelUDF = udf { (label : Double ) =>
199- if (label.toInt == index) 1.0 else 0.0
200- }
201-
202197 // generate new label metadata for the binary problem.
203- // TODO: use when ... otherwise after SPARK-7321 is merged
204198 val newLabelMeta = BinaryAttribute .defaultAttr.withName(" label" ).toMetadata()
205199 val labelColName = " mc2b$" + index
206- val trainingDataset =
207- multiclassLabeled.withColumn( labelColName, labelUDF (col($(labelCol))), newLabelMeta)
200+ val trainingDataset = multiclassLabeled.withColumn(
201+ labelColName, when (col($(labelCol)) === index.toDouble, 1.0 ).otherwise( 0.0 ), newLabelMeta)
208202 val classifier = getClassifier
209203 val paramMap = new ParamMap ()
210204 paramMap.put(classifier.labelCol -> labelColName)
0 commit comments