-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11549][Docs] Replace example code in mllib-evaluation-metrics.md using include_example #9689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
377d7a9
cb9c846
ed33687
3af5fa3
1106cae
ad3c01e
4d18447
3c40a35
892591b
8d2d508
54008ce
1c5cc8f
88512e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| package org.apache.spark.examples.mllib; | ||
|
|
||
| // $example on$ | ||
|
|
||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
|
|
@@ -32,82 +33,80 @@ | |
| import org.apache.spark.SparkConf; | ||
| import org.apache.spark.SparkContext; | ||
|
|
||
|
|
||
| public class JavaBinaryClassification { | ||
|
||
| public static void main(String[] args) { | ||
|
|
||
| SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); | ||
| SparkContext sc = new SparkContext(conf); | ||
| String path = "data/mllib/sample_binary_classification_data.txt"; | ||
| JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); | ||
|
|
||
| // Split initial RDD into two... [60% training data, 40% testing data]. | ||
| JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); | ||
| JavaRDD<LabeledPoint> training = splits[0].cache(); | ||
| JavaRDD<LabeledPoint> test = splits[1]; | ||
|
|
||
| // Run training algorithm to build the model. | ||
| final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() | ||
| .setNumClasses(2) | ||
| .run(training.rdd()); | ||
|
|
||
| // Clear the prediction threshold so the model will return probabilities | ||
| model.clearThreshold(); | ||
|
|
||
| // Compute raw scores on the test set. | ||
| JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( | ||
| new Function<LabeledPoint, Tuple2<Object, Object>>() { | ||
| public Tuple2<Object, Object> call(LabeledPoint p) { | ||
| Double prediction = model.predict(p.features()); | ||
| return new Tuple2<Object, Object>(prediction, p.label()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Get evaluation metrics. | ||
| BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); | ||
|
|
||
| // Precision by threshold | ||
| JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD(); | ||
| System.out.println("Precision by threshold: " + precision.toArray()); | ||
|
|
||
| // Recall by threshold | ||
| JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD(); | ||
| System.out.println("Recall by threshold: " + recall.toArray()); | ||
|
|
||
| // F Score by threshold | ||
| JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); | ||
| System.out.println("F1 Score by threshold: " + f1Score.toArray()); | ||
|
|
||
| JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); | ||
| System.out.println("F2 Score by threshold: " + f2Score.toArray()); | ||
|
|
||
| // Precision-recall curve | ||
| JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD(); | ||
| System.out.println("Precision-recall curve: " + prc.toArray()); | ||
|
|
||
| // Thresholds | ||
| JavaRDD<Double> thresholds = precision.map( | ||
| new Function<Tuple2<Object, Object>, Double>() { | ||
| public Double call(Tuple2<Object, Object> t) { | ||
| return new Double(t._1().toString()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // ROC Curve | ||
| JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD(); | ||
| System.out.println("ROC curve: " + roc.toArray()); | ||
|
|
||
| // AUPRC | ||
| System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); | ||
|
|
||
| // AUROC | ||
| System.out.println("Area under ROC = " + metrics.areaUnderROC()); | ||
|
|
||
| // Save and load model | ||
| model.save(sc, "myModelPath"); | ||
| LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); | ||
| } | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); | ||
| SparkContext sc = new SparkContext(conf); | ||
|
||
| String path = "data/mllib/sample_binary_classification_data.txt"; | ||
| JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); | ||
|
|
||
| // Split initial RDD into two... [60% training data, 40% testing data]. | ||
| JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); | ||
| JavaRDD<LabeledPoint> training = splits[0].cache(); | ||
| JavaRDD<LabeledPoint> test = splits[1]; | ||
|
|
||
| // Run training algorithm to build the model. | ||
| final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() | ||
| .setNumClasses(2) | ||
|
||
| .run(training.rdd()); | ||
|
|
||
| // Clear the prediction threshold so the model will return probabilities | ||
| model.clearThreshold(); | ||
|
|
||
| // Compute raw scores on the test set. | ||
| JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( | ||
| new Function<LabeledPoint, Tuple2<Object, Object>>() { | ||
|
||
| public Tuple2<Object, Object> call(LabeledPoint p) { | ||
| Double prediction = model.predict(p.features()); | ||
| return new Tuple2<Object, Object>(prediction, p.label()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Get evaluation metrics. | ||
| BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); | ||
|
|
||
| // Precision by threshold | ||
| JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD(); | ||
| System.out.println("Precision by threshold: " + precision.toArray()); | ||
|
|
||
| // Recall by threshold | ||
| JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD(); | ||
| System.out.println("Recall by threshold: " + recall.toArray()); | ||
|
|
||
| // F Score by threshold | ||
| JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); | ||
| System.out.println("F1 Score by threshold: " + f1Score.toArray()); | ||
|
|
||
| JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); | ||
| System.out.println("F2 Score by threshold: " + f2Score.toArray()); | ||
|
|
||
| // Precision-recall curve | ||
| JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD(); | ||
| System.out.println("Precision-recall curve: " + prc.toArray()); | ||
|
|
||
| // Thresholds | ||
| JavaRDD<Double> thresholds = precision.map( | ||
| new Function<Tuple2<Object, Object>, Double>() { | ||
| public Double call(Tuple2<Object, Object> t) { | ||
| return new Double(t._1().toString()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // ROC Curve | ||
| JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD(); | ||
| System.out.println("ROC curve: " + roc.toArray()); | ||
|
|
||
| // AUPRC | ||
| System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); | ||
|
|
||
| // AUROC | ||
| System.out.println("Area under ROC = " + metrics.areaUnderROC()); | ||
|
|
||
| // Save and load model | ||
| model.save(sc, "myModelPath"); | ||
|
||
| LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); | ||
|
||
| } | ||
| } | ||
| // $example off$ | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,60 +19,63 @@ | |
| package org.apache.spark.examples.mllib; | ||
|
|
||
| // $example on$ | ||
|
|
||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
| import org.apache.spark.rdd.RDD; | ||
| import org.apache.spark.mllib.evaluation.MultilabelMetrics; | ||
| import org.apache.spark.SparkConf; | ||
|
|
||
| import java.util.Arrays; | ||
| import java.util.List; | ||
| // $example off$ | ||
| import org.apache.spark.SparkContext; | ||
|
|
||
| // $example on$ | ||
|
||
| public class JavaMultiLabelClassification { | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); | ||
| JavaSparkContext sc = new JavaSparkContext(conf); | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); | ||
| JavaSparkContext sc = new JavaSparkContext(conf); | ||
|
|
||
|
||
| List<Tuple2<double[], double[]>> data = Arrays.asList( | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0}) | ||
| ); | ||
| JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data); | ||
| List<Tuple2<double[], double[]>> data = Arrays.asList( | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0}) | ||
| ); | ||
| JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data); | ||
|
|
||
| // Instantiate metrics object | ||
| MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); | ||
| // Instantiate metrics object | ||
| MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); | ||
|
|
||
| // Summary stats | ||
| System.out.format("Recall = %f\n", metrics.recall()); | ||
| System.out.format("Precision = %f\n", metrics.precision()); | ||
| System.out.format("F1 measure = %f\n", metrics.f1Measure()); | ||
| System.out.format("Accuracy = %f\n", metrics.accuracy()); | ||
| // Summary stats | ||
| System.out.format("Recall = %f\n", metrics.recall()); | ||
| System.out.format("Precision = %f\n", metrics.precision()); | ||
| System.out.format("F1 measure = %f\n", metrics.f1Measure()); | ||
| System.out.format("Accuracy = %f\n", metrics.accuracy()); | ||
|
|
||
| // Stats by labels | ||
| for (int i = 0; i < metrics.labels().length - 1; i++) { | ||
| System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); | ||
| System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); | ||
| System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); | ||
| } | ||
| // Stats by labels | ||
| for (int i = 0; i < metrics.labels().length - 1; i++) { | ||
| System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); | ||
|
||
| System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); | ||
| System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); | ||
| } | ||
|
|
||
| // Micro stats | ||
| System.out.format("Micro recall = %f\n", metrics.microRecall()); | ||
| System.out.format("Micro precision = %f\n", metrics.microPrecision()); | ||
| System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); | ||
| // Micro stats | ||
| System.out.format("Micro recall = %f\n", metrics.microRecall()); | ||
| System.out.format("Micro precision = %f\n", metrics.microPrecision()); | ||
| System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); | ||
|
|
||
| // Hamming loss | ||
| System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); | ||
| // Hamming loss | ||
| System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); | ||
|
|
||
| // Subset accuracy | ||
| System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); | ||
| // Subset accuracy | ||
| System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); | ||
|
|
||
|
||
| } | ||
| } | ||
| } | ||
| // $example off$ | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should not be a blank line here. Please remove it.