[SPARK-9906] [ML] User guide for LogisticRegressionSummary #8197
@@ -127,3 +127,168 @@ print("Intercept: " + str(lrModel.intercept))

The optimization algorithm underlying the implementation is called [Orthant-Wise Limited-memory Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
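
For illustration (this snippet is not part of the patch itself), the L1/elastic-net behavior
that OWL-QN enables is controlled through the `regParam` and `elasticNetParam` parameters of
`LogisticRegression`; the concrete values below are arbitrary:

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// elasticNetParam = 0.0 gives a pure L2 penalty, 1.0 gives pure L1
// (the case OWL-QN handles), and values in between mix the two.
val lrElasticNet = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)        // overall regularization strength (arbitrary here)
  .setElasticNetParam(0.8) // mostly L1
{% endhighlight %}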

### Model Summaries

Once a linear model is fit on data, it is useful to extract statistics such as the
loss per iteration and metrics that describe how well the model performs on the
training and test data. The examples below show how to use the summaries obtained
via the `summary` method of the fitted linear models.

Note that the predictions and metrics obtained from the summary are stored as
DataFrames; they are transient and will not be available after the model is
serialized, because they can be as expensive to store as the original data itself.
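
If the metrics need to outlive the summary, one workaround (a sketch, not part of the
patch) is to materialize the DataFrame yourself while the summary is still in scope,
for example by caching it or writing it out:

{% highlight scala %}
// `binarySummary` refers to the cast summary obtained as in the example below.
// The output path is hypothetical; adjust as needed.
val roc = binarySummary.roc
roc.cache()                       // keep the ROC curve in memory for reuse
roc.write.parquet("/tmp/lr_roc")  // or persist it to durable storage
{% endhighlight %}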

#### Example: Summary for LogisticRegression

<div class="codetabs">
<div data-lang="scala">
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides an interface to access information such as the `objectiveHistory` and metrics
for evaluating performance on the training data directly, with minimal additional code
required from the user. [`LogisticRegression`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegression)
currently supports only binary classification, so to access the binary metrics the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary),
as done in the code below. This avoids raising errors for multiclass outputs while
providing extensibility for when multiclass classification is supported in the future.

This example illustrates the use of `LogisticRegressionTrainingSummary` on some toy data.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.{Row, SQLContext}

val conf = new SparkConf().setAppName("LogisticRegressionSummary")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Use some random data for demonstration.
// Note that the RDD of LabeledPoints can be converted to a DataFrame directly.
val data = sc.parallelize(Array(
  LabeledPoint(0.0, Vectors.dense(0.2, 4.5, 1.6)),
  LabeledPoint(1.0, Vectors.dense(3.1, 6.8, 3.6)),
  LabeledPoint(0.0, Vectors.dense(2.4, 0.9, 1.9)),
  LabeledPoint(1.0, Vectors.dense(9.1, 3.1, 3.6)),
  LabeledPoint(0.0, Vectors.dense(2.5, 1.9, 9.1))
))
val logRegDataFrame = data.toDF()

// Run Logistic Regression on your toy data.
// Since LogisticRegression is an estimator, it returns an instance of LogisticRegressionModel,
// which is a transformer.
val logReg = new LogisticRegression().setMaxIter(5).setRegParam(0.01)
val logRegModel = logReg.fit(logRegDataFrame)

// Extract the summary directly from the returned LogisticRegressionModel instance.
val trainingSummary = logRegModel.summary

// Obtain the loss per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory.foreach(loss => println(loss))

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]

// Obtain the receiver-operating characteristic as a DataFrame, and the areaUnderROC.
val roc = binarySummary.roc
roc.show()
roc.select("FPR").show()

// [Review comment, Member] I'd remove this line and instead add a comment to the
// previous one if you want to say what the column names are.

println(binarySummary.areaUnderROC)

// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
  .select("threshold").head().getDouble(0)
logReg.setThreshold(bestThreshold)

// [Review comment, Member] There's no need to re-fit the model since the threshold
// is only used during prediction. Instead, set the threshold in lrModel.
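
// A sketch of the reviewer's suggestion above (an alternative to re-fitting, not
// part of the original patch): set the threshold on the already-fitted model,
// since the threshold is only consulted at prediction time.
// logRegModel.setThreshold(bestThreshold)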

logReg.fit(logRegDataFrame)
{% endhighlight %}
</div>

| <div data-lang="java"> | ||
| [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary) | ||
| provides an interface to access information such as `objectiveHistory` and metrics | ||
| to evaluate the performance on the training data directly with very less code to be rewritten by | ||
| the user. [`LogisticRegression`](api/java/org/apache/spark/ml/classification/LogisticRegression) | ||
| currently supports only binary classification and hence in order to access the binary metrics | ||
| the summary must be explicitly cast to | ||
| [BinaryLogisticRegressionTrainingSummary](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary) | ||
| as done in the code below. This avoids raising errors for multiclass outputs while providing | ||
| extensiblity when multiclass classification is supported in the future | ||
|
|
||
| This example illustrates the use of `LogisticRegressionTrainingSummary` on some toy data. | ||
|
|
||
{% highlight java %}
import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.sql.functions.*;

SparkConf conf = new SparkConf().setAppName("LogisticRegressionSummary");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

// Use some random data for demonstration.
// Note that the RDD of LabeledPoints can be converted to a DataFrame directly.
JavaRDD<LabeledPoint> data = jsc.parallelize(Lists.newArrayList(
  new LabeledPoint(0.0, Vectors.dense(0.2, 4.5, 1.6)),
  new LabeledPoint(1.0, Vectors.dense(3.1, 6.8, 3.6)),
  new LabeledPoint(0.0, Vectors.dense(2.4, 0.9, 1.9)),
  new LabeledPoint(1.0, Vectors.dense(9.1, 3.1, 3.6)),
  new LabeledPoint(0.0, Vectors.dense(2.5, 1.9, 9.1))
));
DataFrame logRegDataFrame = jsql.createDataFrame(data, LabeledPoint.class);

// Run Logistic Regression on your toy data.
// Since LogisticRegression is an estimator, it returns an instance of LogisticRegressionModel,
// which is a transformer.
LogisticRegression logReg = new LogisticRegression().setMaxIter(5).setRegParam(0.01);
LogisticRegressionModel logRegModel = logReg.fit(logRegDataFrame);

// Extract the summary directly from the returned LogisticRegressionModel instance.
LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();

// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
  System.out.println(lossPerIteration);
}

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary;

// Obtain the receiver-operating characteristic as a DataFrame, and the areaUnderROC.
DataFrame roc = binarySummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());

// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
DataFrame fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
  .select("threshold").head().getDouble(0);
logReg.setThreshold(bestThreshold);
logReg.fit(logRegDataFrame);
{% endhighlight %}
</div>
</div>

**Review comment (Member):** "loss" --> "objective" (which includes regularization).