940 changes: 15 additions & 925 deletions docs/mllib-evaluation-metrics.md

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java

@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Contributor: add a blank line here.

package org.apache.spark.examples.mllib;
Contributor: add a blank line here.

// $example on$
import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
// $example off$
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;

public class JavaBinaryClassificationMetricsExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
SparkContext sc = new SparkContext(conf);
// $example on$
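// Load binary classification data in LIBSVM format.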
String path = "data/mllib/sample_binary_classification_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

// Split initial RDD into two... [60% training data, 40% testing data].
JavaRDD<LabeledPoint>[] splits =
data.randomSplit(new double[]{0.6, 0.4}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> test = splits[1];

// Run training algorithm to build the model.
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(2)
.run(training.rdd());

// Clear the prediction threshold so the model will return probabilities
model.clearThreshold();

// Compute raw scores on the test set.
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
new Function<LabeledPoint, Tuple2<Object, Object>>() {
public Tuple2<Object, Object> call(LabeledPoint p) {
Double prediction = model.predict(p.features());
return new Tuple2<Object, Object>(prediction, p.label());
}
}
);

// Get evaluation metrics.
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());

// Precision by threshold
JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
System.out.println("Precision by threshold: " + precision.toArray());

// Recall by threshold
JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
System.out.println("Recall by threshold: " + recall.toArray());

// F Score by threshold
JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
System.out.println("F1 Score by threshold: " + f1Score.toArray());

JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
System.out.println("F2 Score by threshold: " + f2Score.toArray());

// Precision-recall curve
JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
System.out.println("Precision-recall curve: " + prc.toArray());

// Thresholds (the first element of each (threshold, precision) pair)
JavaRDD<Double> thresholds = precision.map(
new Function<Tuple2<Object, Object>, Double>() {
public Double call(Tuple2<Object, Object> t) {
return new Double(t._1().toString());
}
}
);

// ROC Curve
JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
System.out.println("ROC curve: " + roc.toArray());

// AUPRC
System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());

// AUROC
System.out.println("Area under ROC = " + metrics.areaUnderROC());

// Save and load model
model.save(sc, "target/tmp/LogisticRegressionModel");
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
"target/tmp/LogisticRegressionModel");
// $example off$
}
}
80 changes: 80 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java

@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Contributor: add a blank line here.

Contributor: remove a blank line.

package org.apache.spark.examples.mllib;

// $example on$
import java.util.Arrays;
Contributor: change the imports like this:

    // $example on$
    import java.util.Arrays;
    import java.util.List;

    import scala.Tuple2;

    import org.apache.spark.api.java.*;
    import org.apache.spark.mllib.evaluation.MultilabelMetrics;
    import org.apache.spark.SparkConf;
    // $example off$

import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.mllib.evaluation.MultilabelMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.SparkConf;
// $example off$
import org.apache.spark.SparkContext;

public class JavaMultiLabelClassificationMetricsExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example");
JavaSparkContext sc = new JavaSparkContext(conf);
// $example on$
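// Each element is a (predicted labels, true labels) pair for one instance.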
List<Tuple2<double[], double[]>> data = Arrays.asList(
new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
);
JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);

// Instantiate metrics object
MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());

// Summary stats
System.out.format("Recall = %f\n", metrics.recall());
System.out.format("Precision = %f\n", metrics.precision());
System.out.format("F1 measure = %f\n", metrics.f1Measure());
System.out.format("Accuracy = %f\n", metrics.accuracy());

// Stats by labels
for (int i = 0; i < metrics.labels().length; i++) {
System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision
(metrics.labels()[i]));
System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics
.labels()[i]));
System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure
(metrics.labels()[i]));
}

// Micro stats
System.out.format("Micro recall = %f\n", metrics.microRecall());
System.out.format("Micro precision = %f\n", metrics.microPrecision());
System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());

// Hamming loss
System.out.format("Hamming loss = %f\n", metrics.hammingLoss());

// Subset accuracy
System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
// $example off$
}
}
97 changes: 97 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java

@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Contributor: add a blank line here.

package org.apache.spark.examples.mllib;
Contributor: add a blank line here.

// $example on$
import scala.Tuple2;

import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.mllib.linalg.Matrix;
// $example off$
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;

public class JavaMulticlassClassificationMetricsExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example");
SparkContext sc = new SparkContext(conf);
// $example on$
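// Load multiclass classification data in LIBSVM format.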
String path = "data/mllib/sample_multiclass_classification_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

// Split initial RDD into two... [60% training data, 40% testing data].
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> test = splits[1];

// Run training algorithm to build the model.
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(3)
.run(training.rdd());

// Compute raw scores on the test set.
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
new Function<LabeledPoint, Tuple2<Object, Object>>() {
public Tuple2<Object, Object> call(LabeledPoint p) {
Double prediction = model.predict(p.features());
return new Tuple2<Object, Object>(prediction, p.label());
}
}
);

// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

// Confusion matrix
Matrix confusion = metrics.confusionMatrix();
System.out.println("Confusion matrix: \n" + confusion);

// Overall statistics
System.out.println("Precision = " + metrics.precision());
System.out.println("Recall = " + metrics.recall());
System.out.println("F1 Score = " + metrics.fMeasure());

// Stats by labels
for (int i = 0; i < metrics.labels().length; i++) {
System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision
(metrics.labels()[i]));
System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics
.labels()[i]));
System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure
(metrics.labels()[i]));
}

// Weighted stats
System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());

// Save and load model
model.save(sc, "target/tmp/LogisticRegressionModel");
LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
"target/tmp/LogisticRegressionModel");
// $example off$
}
}