Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
940 changes: 15 additions & 925 deletions docs/mllib-evaluation-metrics.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.mllib;

// $example on$

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should not have a blank line. Pls remove it.

import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$example off$ here.

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;

public class JavaBinaryClassification {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd better making these class names more concretely. I.e. JavaBinaryClassificatinMetricsExample. And the same for the following classes.

public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics");
SparkContext sc = new SparkContext(conf);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$example on$ here.

String path = "data/mllib/sample_binary_classification_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

// Split initial RDD into two... [60% training data, 40% testing data].
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> test = splits[1];

// Run training algorithm to build the model.
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the style error. Spark uses 2-space indent in Java and Scala, 4-indent in Python.

.run(training.rdd());

// Clear the prediction threshold so the model will return probabilities
model.clearThreshold();

// Compute raw scores on the test set.
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
new Function<LabeledPoint, Tuple2<Object, Object>>() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix indention. Refer to Spark Scala Guide. Check your code style with dev/scalastyle, dev/lint-python, etc.

public Tuple2<Object, Object> call(LabeledPoint p) {
Double prediction = model.predict(p.features());
return new Tuple2<Object, Object>(prediction, p.label());
}
}
);

// Get evaluation metrics.
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());

// Precision by threshold
JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
System.out.println("Precision by threshold: " + precision.toArray());

// Recall by threshold
JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
System.out.println("Recall by threshold: " + recall.toArray());

// F Score by threshold
JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
System.out.println("F1 Score by threshold: " + f1Score.toArray());

JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
System.out.println("F2 Score by threshold: " + f2Score.toArray());

// Precision-recall curve
JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
System.out.println("Precision-recall curve: " + prc.toArray());

// Thresholds
JavaRDD<Double> thresholds = precision.map(
new Function<Tuple2<Object, Object>, Double>() {
public Double call(Tuple2<Object, Object> t) {
return new Double(t._1().toString());
}
}
);

// ROC Curve
JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
System.out.println("ROC curve: " + roc.toArray());

// AUPRC
System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());

// AUROC
System.out.println("Area under ROC = " + metrics.areaUnderROC());

// Save and load model
model.save(sc, "myModelPath");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pls change all the "myModelPath" to "target/tmp/XXXModel". Here we can use "target/tmp/LogisticRegressionModel". Change the following paths accordingly.

LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$example off$ here

}
}
// $example off$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete this line

Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.mllib;

// $example on$

import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.SparkConf;

// Read in the ratings data
public class JavaLinearRegression {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
JavaSparkContext sc = new JavaSparkContext(conf);

// Load and parse the data
String path = "data/mllib/sample_linear_regression_data.txt";
JavaRDD<String> data = sc.textFile(path);
JavaRDD<LabeledPoint> parsedData = data.map(
new Function<String, LabeledPoint>() {
public LabeledPoint call(String line) {
String[] parts = line.split(" ");
double[] v = new double[parts.length - 1];
for (int i = 1; i < parts.length - 1; i++)
v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
}
}
);
parsedData.cache();

// Building the model
int numIterations = 100;
final LinearRegressionModel model =
LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);

// Evaluate model on training examples and compute training error
JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
new Function<LabeledPoint, Tuple2<Object, Object>>() {
public Tuple2<Object, Object> call(LabeledPoint point) {
double prediction = model.predict(point.features());
return new Tuple2<Object, Object>(prediction, point.label());
}
}
);

// Instantiate metrics object
RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());

// Squared error
System.out.format("MSE = %f\n", metrics.meanSquaredError());
System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());

// R-squared
System.out.format("R Squared = %f\n", metrics.r2());

// Mean absolute error
System.out.format("MAE = %f\n", metrics.meanAbsoluteError());

// Explained variance
System.out.format("Explained Variance = %f\n", metrics.explainedVariance());

// Save and load model
model.save(sc.sc(), "myModelPath");
LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
}
}
// $example off$
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.mllib;

// $example off$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why $example off$ here?

import java.util.Arrays;
import java.util.List;
// $example on$
import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.mllib.evaluation.MultilabelMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.SparkConf;
// $example off$
import org.apache.spark.SparkContext;

// $example on$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not place $example on$ in the beginning of a class.

public class JavaMultiLabelClassification {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
JavaSparkContext sc = new JavaSparkContext(conf);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$example on$ here.

List<Tuple2<double[], double[]>> data = Arrays.asList(
new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
);
JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);

// Instantiate metrics object
MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());

// Summary stats
System.out.format("Recall = %f\n", metrics.recall());
System.out.format("Precision = %f\n", metrics.precision());
System.out.format("F1 measure = %f\n", metrics.f1Measure());
System.out.format("Accuracy = %f\n", metrics.accuracy());

// Stats by labels
for (int i = 0; i < metrics.labels().length - 1; i++) {
System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use 100 chars per line. Wrap the line if it exceeds 100 chars.

System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
}

// Micro stats
System.out.format("Micro recall = %f\n", metrics.microRecall());
System.out.format("Micro precision = %f\n", metrics.microPrecision());
System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());

// Hamming loss
System.out.format("Hamming loss = %f\n", metrics.hammingLoss());

// Subset accuracy
System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

substitute the blank line with a $example off$.

}
}
// $example off$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.mllib

// $example on$

import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.rdd.RDD;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;

public class JavaMulticlassClassification {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics");
SparkContext sc = new SparkContext(conf);
String path = "data/mllib/sample_multiclass_classification_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

// Split initial RDD into two... [60% training data, 40% testing data].
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> test = splits[1];

// Run training algorithm to build the model.
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(3)
.run(training.rdd());

// Compute raw scores on the test set.
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
new Function<LabeledPoint, Tuple2<Object, Object>>() {
public Tuple2<Object, Object> call(LabeledPoint p) {
Double prediction = model.predict(p.features());
return new Tuple2<Object, Object>(prediction, p.label());
}
}
);

// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

// Confusion matrix
Matrix confusion = metrics.confusionMatrix();
System.out.println("Confusion matrix: \n" + confusion);

// Overall statistics
System.out.println("Precision = " + metrics.precision());
System.out.println("Recall = " + metrics.recall());
System.out.println("F1 Score = " + metrics.fMeasure());

// Stats by labels
for (int i = 0; i < metrics.labels().length; i++) {
System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i]));
}

//Weighted stats
System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());

// Save and load model
model.save(sc, "myModelPath");
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
}
}
// $example off$
Loading