[SPARK-9906] [ML] User guide for LogisticRegressionSummary #8197
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,20 +23,41 @@ displayTitle: <a href="ml-guide.html">ML</a> - Linear Methods | |
| \]` | ||
|
|
||
|
|
||
| In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: | ||
| In MLlib, we implement popular linear methods such as logistic | ||
| regression and linear least squares with $L_1$ or $L_2$ regularization. | ||
| Refer to [the linear methods in mllib](mllib-linear-methods.html) for | ||
| details. In `spark.ml`, we also include Pipelines API for [Elastic | ||
| net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid | ||
| of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization | ||
| and variable selection via the elastic | ||
| net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). | ||
| Mathematically, it is defined as a convex combination of the $L_1$ and | ||
| the $L_2$ regularization terms: | ||
| `\[ | ||
| \alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. | ||
| \alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. | ||
| \]` | ||
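Reviewer note: the combined penalty in the added equation can be sanity-checked with a small pure-Python sketch. This is illustrative only (no Spark dependency; the function name and inputs are not part of the patch):

```python
def elastic_net_penalty(w, alpha, lam):
    """alpha * lam * ||w||_1 + (1 - alpha) * (lam / 2) * ||w||_2^2,
    with 0 <= alpha <= 1 and lam >= 0, matching the equation above."""
    l1 = sum(abs(wi) for wi in w)          # L1 norm of the weights
    l2_sq = sum(wi * wi for wi in w)       # squared L2 norm
    return alpha * lam * l1 + (1 - alpha) * (lam / 2.0) * l2_sq
```

With `w = [3.0, -4.0]` and `lam = 1.0`, setting `alpha = 1` gives 7.0 (the pure L1/Lasso penalty) and `alpha = 0` gives 12.5 (half the squared L2 norm, i.e. the ridge penalty), illustrating the special cases discussed below.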
| By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. | ||
|
|
||
| **Examples** | ||
| By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ | ||
| regularization as special cases. For example, if a [linear | ||
| regression](https://en.wikipedia.org/wiki/Linear_regression) model is | ||
| trained with the elastic net parameter $\alpha$ set to $1$, it is | ||
| equivalent to a | ||
| [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. | ||
| On the other hand, if $\alpha$ is set to $0$, the trained model reduces | ||
| to a [ridge | ||
| regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. | ||
| We implement Pipelines API for both linear regression and logistic | ||
| regression with elastic net regularization. | ||
|
|
||
| ## Example: Logistic Regression | ||
|
|
||
| The following example shows how to train a logistic regression model | ||
| with elastic net regularization. `elasticNetParam` corresponds to | ||
| $\alpha$ and `regParam` corresponds to $\lambda$. | ||
|
|
||
| <div class="codetabs"> | ||
|
|
||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| {% highlight scala %} | ||
|
|
||
| import org.apache.spark.ml.classification.LogisticRegression | ||
| import org.apache.spark.mllib.util.MLUtils | ||
|
|
||
|
|
@@ -53,15 +74,11 @@ val lrModel = lr.fit(training) | |
|
|
||
| // Print the weights and intercept for logistic regression | ||
| println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") | ||
|
|
||
| {% endhighlight %} | ||
|
|
||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| {% highlight java %} | ||
|
|
||
| import org.apache.spark.ml.classification.LogisticRegression; | ||
| import org.apache.spark.ml.classification.LogisticRegressionModel; | ||
| import org.apache.spark.mllib.regression.LabeledPoint; | ||
|
|
@@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample { | |
| </div> | ||
|
|
||
| <div data-lang="python" markdown="1"> | ||
|
|
||
| {% highlight python %} | ||
|
|
||
| from pyspark.ml.classification import LogisticRegression | ||
| from pyspark.mllib.regression import LabeledPoint | ||
| from pyspark.mllib.util import MLUtils | ||
|
|
@@ -118,12 +133,114 @@ lrModel = lr.fit(training) | |
| print("Weights: " + str(lrModel.weights)) | ||
| print("Intercept: " + str(lrModel.intercept)) | ||
| {% endhighlight %} | ||
| </div> | ||
|
|
||
| </div> | ||
|
|
||
| The `spark.ml` implementation of logistic regression also supports | ||
| extracting a summary of the model over the training set. Note that the | ||
| predictions and metrics which are stored as `DataFrame` in | ||
| `BinaryLogisticRegressionSummary` are annotated `@transient` and hence | ||
| only available on the driver. | ||
|
|
||
| <div class="codetabs"> | ||
|
|
||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) | ||
| provides a summary for a | ||
| [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). | ||
| Currently, only binary classification is supported and the | ||
| summary must be explicitly cast to | ||
| [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). | ||
| This will likely change when multiclass classification is supported. | ||
|
Contributor
Author
Actually the cast will not change. Even when multiclass classification is supported, we will have to do an explicit cast since the binary metrics will not be available in the sealed
Contributor
Author
I suggest we just remove this line. What say?
Contributor
Downcasting is almost always an indication of a poor abstraction and IMO the stabilized API should not require any explicit typecasting by the end user, here's an explanation
Contributor
Author
The other option would be just to make all metrics available in See:
Contributor
Synced with @jkbradley offline. Summary: We should not require end users to perform any sort of downcasting in the stabilized API. This is OK for now since the API is still experimental. Eventually we could provide two methods, a
||
|
|
||
| Continuing the earlier example: | ||
|
|
||
| {% highlight scala %} | ||
| // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example | ||
| val trainingSummary = lrModel.summary | ||
|
|
||
| // Obtain the loss per iteration. | ||
|
Member
"loss" --> "objective" (which includes regularization)
||
| val objectiveHistory = trainingSummary.objectiveHistory | ||
| objectiveHistory.foreach(loss => println(loss)) | ||
|
|
||
| // Obtain the metrics useful to judge performance on test data. | ||
| // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a | ||
| // binary classification problem. | ||
| val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] | ||
|
|
||
| // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. | ||
| val roc = binarySummary.roc | ||
| roc.show() | ||
| roc.select("FPR").show() | ||
|
Member
I'd remove this line and instead add a comment to the previous one if you want to say what the column names are.
||
| println(binarySummary.areaUnderROC) | ||
|
|
||
| // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with | ||
| // this selected threshold. | ||
| val fMeasure = binarySummary.fMeasureByThreshold | ||
| val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) | ||
| val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). | ||
| select("threshold").head().getDouble(0) | ||
| logReg.setThreshold(bestThreshold) | ||
|
Member
There's no need to re-fit the model since the threshold is only used during prediction. Instead, set the threshold in lrModel.
||
| logReg.fit(logRegDataFrame) | ||
| {% endhighlight %} | ||
| </div> | ||
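Reviewer note: the threshold-selection step in the Scala snippet above (take the threshold with maximum F-measure, then `setThreshold`) can be sketched without Spark. The pair list below stands in for the rows of `fMeasureByThreshold`, and the helper names are hypothetical:

```python
def f_measure(precision, recall):
    """Harmonic mean of precision and recall (F1); 0 when both are 0."""
    if precision + recall == 0.0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)

def best_threshold(f_by_threshold):
    """Given (threshold, f_measure) pairs, return the threshold
    whose F-measure is maximal."""
    threshold, _ = max(f_by_threshold, key=lambda pair: pair[1])
    return threshold
```

For example, `best_threshold([(0.1, 0.6), (0.5, 0.9), (0.8, 0.7)])` returns 0.5, mirroring the `select(max("F-Measure"))` / `where(...)` pattern in the snippet.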
|
|
||
| ### Optimization | ||
| <div data-lang="java" markdown="1"> | ||
|
Member
Please add the needed imports for this new Java section.
||
| [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) | ||
| provides a summary for a | ||
| [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). | ||
| Currently, only binary classification is supported and the | ||
| summary must be explicitly cast to | ||
| [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). | ||
| This will likely change when multiclass classification is supported. | ||
|
|
||
| Continuing the earlier example: | ||
|
|
||
| {% highlight java %} | ||
| // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example | ||
| LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); | ||
|
|
||
| // Obtain the loss per iteration. | ||
| double[] objectiveHistory = trainingSummary.objectiveHistory(); | ||
| for (double lossPerIteration : objectiveHistory) { | ||
|
Contributor
Author
I don't know about Java style but from the other examples I think the colon should be one place to the left
Contributor
I don't think so, see Google's Java Style Guide
Contributor
Author
I see then other places such as this (https://github.com/apache/spark/blob/master/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java#L68) have to be changed.
Contributor
Perhaps... the Spark style guide only covers Scala and Python so I was just going by past experience on this change. We could try to get a Java style guide in if there's a community need for it
||
| System.out.println(lossPerIteration); | ||
| } | ||
|
|
||
| // Obtain the metrics useful to judge performance on test data. | ||
| // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a | ||
| // binary classification problem. | ||
| BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; | ||
|
|
||
| // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. | ||
| DataFrame roc = binarySummary.roc(); | ||
| roc.show(); | ||
| roc.select("FPR").show(); | ||
| System.out.println(binarySummary.areaUnderROC()); | ||
|
|
||
| // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with | ||
| // this selected threshold. | ||
| DataFrame fMeasure = binarySummary.fMeasureByThreshold(); | ||
| double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); | ||
| double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). | ||
| select("threshold").head().getDouble(0); | ||
| logReg.setThreshold(bestThreshold); | ||
| logReg.fit(logRegDataFrame); | ||
| {% endhighlight %} | ||
| </div> | ||
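Reviewer note: `areaUnderROC`, used in both snippets above, is the standard ROC area. For intuition, a dependency-free sketch (names illustrative, not Spark API) computes it as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, counting ties as one half, which is equivalent to the trapezoidal area under the ROC curve:

```python
def area_under_roc(scores, labels):
    """AUC as the pairwise ranking probability: fraction of
    (positive, negative) pairs where the positive scores higher,
    with ties counted as 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    total = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                total += 1.0
            elif p == n:
                total += 0.5
    return total / (len(pos) * len(neg))
```

A classifier that ranks every positive above every negative scores 1.0; random scoring hovers around 0.5.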
|
|
||
| <div data-lang="python" markdown="1"> | ||
| Logistic regression model summary is not yet supported in Python. | ||
| </div> | ||
|
|
||
| </div> | ||
|
|
||
| # Optimization | ||
|
|
||
| The optimization algorithm underlying the implementation is called | ||
| [Orthant-Wise Limited-memory | ||
| QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) | ||
| (OWL-QN). It is an extension of L-BFGS that can effectively handle L1 | ||
| regularization and elastic net. | ||
|
|
||
| The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) | ||
| (OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. | ||
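Reviewer note: the guide doesn't need this, but for readers curious how a solver can "effectively handle" a non-differentiable L1 term, the usual one-line intuition is the soft-thresholding (proximal) operator used by proximal-gradient L1 solvers. This is a generic sketch of that operator, not OWL-QN itself:

```python
def soft_threshold(w, step, lam):
    """Proximal operator of step * lam * ||.||_1: shrink each weight
    toward zero by step * lam, and zero out weights already within
    that band. This is what makes L1 regularization produce sparsity."""
    t = step * lam
    out = []
    for wi in w:
        if wi > t:
            out.append(wi - t)
        elif wi < -t:
            out.append(wi + t)
        else:
            out.append(0.0)
    return out
```

For example, with `step * lam = 0.1`, the weight `-0.05` is zeroed while `0.5` is shrunk to `0.4`; OWL-QN achieves a similar effect inside an L-BFGS-style quasi-Newton update by restricting steps to an orthant.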
I don't think you need to put lambda in, but if you do, then how about putting it outside of big parentheses or brackets to make the equation easier to read?