Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make rawPrediction optional
  • Loading branch information
lu-wang-dl committed Apr 13, 2018
commit 2a47e2be30d52e3fbea7e1eeeaa5048a6ac97116
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}

model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
Expand All @@ -206,18 +207,31 @@ final class OneVsRestModel private[ml] (
}

// output the RawPrediction as vector
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
Vectors.sparse(numClasses, predictions.toList )
}
if (getRawPredictionCol != "") {
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
val myArray = Array.fill[Double](numClasses)(0.0)
predictions.foreach { case (idx, value) => myArray(idx) = value }
Vectors.dense(myArray)
}

// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

==> udf { (rawPredictions: Vector) => ... }


// output confidence as raw prediction, label and label metadata as prediction
aggregatedDataset
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
.drop(accColName)
aggregatedDataset
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
.drop(accColName)
}
else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scala style: This should go on the previous line: } else {

// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
// output label and label metadata as prediction (no raw prediction column in this branch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rwa -> raw

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems to be in the wrong part of the code. Also there's a typo

aggregatedDataset
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}
}

@Since("1.4.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
ovaModel.setFeaturesCol("fea")
ovaModel.setPredictionCol("pred")
ovaModel.setRawPredictionCol("rawpred")
ovaModel.setRawPredictionCol("")
val transformedDataset = ovaModel.transform(dataset2)
val outputFields = transformedDataset.schema.fieldNames.toSet
assert(outputFields === Set("y", "fea", "pred", "rawpred"))
assert(outputFields === Set("y", "fea", "pred"))
}

test("SPARK-8049: OneVsRest shouldn't output temp columns") {
Expand Down