-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11549][Docs] Replace example code in mllib-evaluation-metrics.md using include_example #9689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
377d7a9
cb9c846
ed33687
3af5fa3
1106cae
ad3c01e
4d18447
3c40a35
892591b
8d2d508
54008ce
1c5cc8f
88512e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // scalastyle:off println | ||
| package org.apache.spark.examples.mllib; | ||
|
|
||
| // $example on$ | ||
|
|
||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
| import org.apache.spark.api.java.function.Function; | ||
| import org.apache.spark.mllib.classification.LogisticRegressionModel; | ||
| import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; | ||
| import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; | ||
| import org.apache.spark.mllib.regression.LabeledPoint; | ||
| import org.apache.spark.mllib.util.MLUtils; | ||
| import org.apache.spark.rdd.RDD; | ||
|
||
| import org.apache.spark.SparkConf; | ||
| import org.apache.spark.SparkContext; | ||
|
|
||
| public class JavaBinaryClassification { | ||
|
||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); | ||
| SparkContext sc = new SparkContext(conf); | ||
|
||
| String path = "data/mllib/sample_binary_classification_data.txt"; | ||
| JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); | ||
|
|
||
| // Split initial RDD into two... [60% training data, 40% testing data]. | ||
| JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); | ||
| JavaRDD<LabeledPoint> training = splits[0].cache(); | ||
| JavaRDD<LabeledPoint> test = splits[1]; | ||
|
|
||
| // Run training algorithm to build the model. | ||
| final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() | ||
| .setNumClasses(2) | ||
|
||
| .run(training.rdd()); | ||
|
|
||
| // Clear the prediction threshold so the model will return probabilities | ||
| model.clearThreshold(); | ||
|
|
||
| // Compute raw scores on the test set. | ||
| JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( | ||
| new Function<LabeledPoint, Tuple2<Object, Object>>() { | ||
|
||
| public Tuple2<Object, Object> call(LabeledPoint p) { | ||
| Double prediction = model.predict(p.features()); | ||
| return new Tuple2<Object, Object>(prediction, p.label()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Get evaluation metrics. | ||
| BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); | ||
|
|
||
| // Precision by threshold | ||
| JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD(); | ||
| System.out.println("Precision by threshold: " + precision.toArray()); | ||
|
|
||
| // Recall by threshold | ||
| JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD(); | ||
| System.out.println("Recall by threshold: " + recall.toArray()); | ||
|
|
||
| // F Score by threshold | ||
| JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); | ||
| System.out.println("F1 Score by threshold: " + f1Score.toArray()); | ||
|
|
||
| JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); | ||
| System.out.println("F2 Score by threshold: " + f2Score.toArray()); | ||
|
|
||
| // Precision-recall curve | ||
| JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD(); | ||
| System.out.println("Precision-recall curve: " + prc.toArray()); | ||
|
|
||
| // Thresholds | ||
| JavaRDD<Double> thresholds = precision.map( | ||
| new Function<Tuple2<Object, Object>, Double>() { | ||
| public Double call(Tuple2<Object, Object> t) { | ||
| return new Double(t._1().toString()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // ROC Curve | ||
| JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD(); | ||
| System.out.println("ROC curve: " + roc.toArray()); | ||
|
|
||
| // AUPRC | ||
| System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); | ||
|
|
||
| // AUROC | ||
| System.out.println("Area under ROC = " + metrics.areaUnderROC()); | ||
|
|
||
| // Save and load model | ||
| model.save(sc, "myModelPath"); | ||
|
||
| LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); | ||
|
||
| } | ||
| } | ||
| // $example off$ | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // scalastyle:off println | ||
| package org.apache.spark.examples.mllib; | ||
|
|
||
| // $example on$ | ||
|
|
||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
| import org.apache.spark.api.java.function.Function; | ||
| import org.apache.spark.mllib.linalg.Vectors; | ||
| import org.apache.spark.mllib.regression.LabeledPoint; | ||
| import org.apache.spark.mllib.regression.LinearRegressionModel; | ||
| import org.apache.spark.mllib.regression.LinearRegressionWithSGD; | ||
| import org.apache.spark.mllib.evaluation.RegressionMetrics; | ||
| import org.apache.spark.SparkConf; | ||
|
|
||
| // Read in the ratings data | ||
| public class JavaLinearRegression { | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); | ||
| JavaSparkContext sc = new JavaSparkContext(conf); | ||
|
|
||
| // Load and parse the data | ||
| String path = "data/mllib/sample_linear_regression_data.txt"; | ||
| JavaRDD<String> data = sc.textFile(path); | ||
| JavaRDD<LabeledPoint> parsedData = data.map( | ||
| new Function<String, LabeledPoint>() { | ||
| public LabeledPoint call(String line) { | ||
| String[] parts = line.split(" "); | ||
| double[] v = new double[parts.length - 1]; | ||
| for (int i = 1; i < parts.length - 1; i++) | ||
| v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); | ||
| return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); | ||
| } | ||
| } | ||
| ); | ||
| parsedData.cache(); | ||
|
|
||
| // Building the model | ||
| int numIterations = 100; | ||
| final LinearRegressionModel model = | ||
| LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); | ||
|
|
||
| // Evaluate model on training examples and compute training error | ||
| JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map( | ||
| new Function<LabeledPoint, Tuple2<Object, Object>>() { | ||
| public Tuple2<Object, Object> call(LabeledPoint point) { | ||
| double prediction = model.predict(point.features()); | ||
| return new Tuple2<Object, Object>(prediction, point.label()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Instantiate metrics object | ||
| RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); | ||
|
|
||
| // Squared error | ||
| System.out.format("MSE = %f\n", metrics.meanSquaredError()); | ||
| System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); | ||
|
|
||
| // R-squared | ||
| System.out.format("R Squared = %f\n", metrics.r2()); | ||
|
|
||
| // Mean absolute error | ||
| System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); | ||
|
|
||
| // Explained variance | ||
| System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); | ||
|
|
||
| // Save and load model | ||
| model.save(sc.sc(), "myModelPath"); | ||
| LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); | ||
| } | ||
| } | ||
| // $example off$ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // scalastyle:off println | ||
| package org.apache.spark.examples.mllib; | ||
|
|
||
| // $example off$ | ||
|
||
| import java.util.Arrays; | ||
| import java.util.List; | ||
| // $example on$ | ||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
| import org.apache.spark.mllib.evaluation.MultilabelMetrics; | ||
| import org.apache.spark.rdd.RDD; | ||
| import org.apache.spark.SparkConf; | ||
| // $example off$ | ||
| import org.apache.spark.SparkContext; | ||
|
|
||
| // $example on$ | ||
|
||
| public class JavaMultiLabelClassification { | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); | ||
| JavaSparkContext sc = new JavaSparkContext(conf); | ||
|
|
||
|
||
| List<Tuple2<double[], double[]>> data = Arrays.asList( | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}), | ||
| new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), | ||
| new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), | ||
| new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0}) | ||
| ); | ||
| JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data); | ||
|
|
||
| // Instantiate metrics object | ||
| MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); | ||
|
|
||
| // Summary stats | ||
| System.out.format("Recall = %f\n", metrics.recall()); | ||
| System.out.format("Precision = %f\n", metrics.precision()); | ||
| System.out.format("F1 measure = %f\n", metrics.f1Measure()); | ||
| System.out.format("Accuracy = %f\n", metrics.accuracy()); | ||
|
|
||
| // Stats by labels | ||
| for (int i = 0; i < metrics.labels().length - 1; i++) { | ||
| System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); | ||
|
||
| System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); | ||
| System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); | ||
| } | ||
|
|
||
| // Micro stats | ||
| System.out.format("Micro recall = %f\n", metrics.microRecall()); | ||
| System.out.format("Micro precision = %f\n", metrics.microPrecision()); | ||
| System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); | ||
|
|
||
| // Hamming loss | ||
| System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); | ||
|
|
||
| // Subset accuracy | ||
| System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); | ||
|
|
||
|
||
| } | ||
| } | ||
| // $example off$ | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // scalastyle:off println | ||
| package org.apache.spark.examples.mllib | ||
|
|
||
| // $example on$ | ||
|
|
||
| import scala.Tuple2; | ||
|
|
||
| import org.apache.spark.api.java.*; | ||
| import org.apache.spark.api.java.function.Function; | ||
| import org.apache.spark.mllib.classification.LogisticRegressionModel; | ||
| import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; | ||
| import org.apache.spark.mllib.evaluation.MulticlassMetrics; | ||
| import org.apache.spark.mllib.regression.LabeledPoint; | ||
| import org.apache.spark.mllib.util.MLUtils; | ||
| import org.apache.spark.mllib.linalg.Matrix; | ||
| import org.apache.spark.rdd.RDD; | ||
| import org.apache.spark.SparkConf; | ||
| import org.apache.spark.SparkContext; | ||
|
|
||
| public class JavaMulticlassClassification { | ||
| public static void main(String[] args) { | ||
| SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics"); | ||
| SparkContext sc = new SparkContext(conf); | ||
| String path = "data/mllib/sample_multiclass_classification_data.txt"; | ||
| JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); | ||
|
|
||
| // Split initial RDD into two... [60% training data, 40% testing data]. | ||
| JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); | ||
| JavaRDD<LabeledPoint> training = splits[0].cache(); | ||
| JavaRDD<LabeledPoint> test = splits[1]; | ||
|
|
||
| // Run training algorithm to build the model. | ||
| final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() | ||
| .setNumClasses(3) | ||
| .run(training.rdd()); | ||
|
|
||
| // Compute raw scores on the test set. | ||
| JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( | ||
| new Function<LabeledPoint, Tuple2<Object, Object>>() { | ||
| public Tuple2<Object, Object> call(LabeledPoint p) { | ||
| Double prediction = model.predict(p.features()); | ||
| return new Tuple2<Object, Object>(prediction, p.label()); | ||
| } | ||
| } | ||
| ); | ||
|
|
||
| // Get evaluation metrics. | ||
| MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); | ||
|
|
||
| // Confusion matrix | ||
| Matrix confusion = metrics.confusionMatrix(); | ||
| System.out.println("Confusion matrix: \n" + confusion); | ||
|
|
||
| // Overall statistics | ||
| System.out.println("Precision = " + metrics.precision()); | ||
| System.out.println("Recall = " + metrics.recall()); | ||
| System.out.println("F1 Score = " + metrics.fMeasure()); | ||
|
|
||
| // Stats by labels | ||
| for (int i = 0; i < metrics.labels().length; i++) { | ||
| System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); | ||
| System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); | ||
| System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); | ||
| } | ||
|
|
||
| //Weighted stats | ||
| System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); | ||
| System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); | ||
| System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); | ||
| System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); | ||
|
|
||
| // Save and load model | ||
| model.save(sc, "myModelPath"); | ||
| LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); | ||
| } | ||
| } | ||
| // $example off$ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should not have a blank line. Pls remove it.