Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 133 additions & 16 deletions docs/ml-linear-methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,41 @@ displayTitle: <a href="ml-guide.html">ML</a> - Linear Methods
\]`


In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm:
In MLlib, we implement popular linear methods such as logistic
regression and linear least squares with $L_1$ or $L_2$ regularization.
Refer to [the linear methods in mllib](mllib-linear-methods.html) for
details. In `spark.ml`, we also include Pipelines API for [Elastic
net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid
of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization
and variable selection via the elastic
net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
Mathematically, it is defined as a convex combination of the $L_1$ and
the $L_2$ regularization terms:
`\[
\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1].
\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to put lambda in, but if you do, then how about putting it outside of big parentheses or brackets to make the equation easier to read?

\]`
By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization.

**Examples**
By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
regularization as special cases. For example, if a [linear
regression](https://en.wikipedia.org/wiki/Linear_regression) model is
trained with the elastic net parameter $\alpha$ set to $1$, it is
equivalent to a
[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model.
On the other hand, if $\alpha$ is set to $0$, the trained model reduces
to a [ridge
regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model.
We implement Pipelines API for both linear regression and logistic
regression with elastic net regularization.

## Example: Logistic Regression

The following example shows how to train a logistic regression model
with elastic net regularization. `elasticNetParam` corresponds to
$\alpha$ and `regParam` corresponds to $\lambda$.

<div class="codetabs">

<div data-lang="scala" markdown="1">

{% highlight scala %}

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.util.MLUtils

Expand All @@ -53,15 +74,11 @@ val lrModel = lr.fit(training)

// Print the weights and intercept for logistic regression
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")

{% endhighlight %}

</div>

<div data-lang="java" markdown="1">

{% highlight java %}

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;
Expand Down Expand Up @@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
</div>

<div data-lang="python" markdown="1">

{% highlight python %}

from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import MLUtils
Expand All @@ -118,12 +133,114 @@ lrModel = lr.fit(training)
print("Weights: " + str(lrModel.weights))
print("Intercept: " + str(lrModel.intercept))
{% endhighlight %}
</div>

</div>

The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `Dataframe` in
`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.

<div class="codetabs">

<div data-lang="scala" markdown="1">

[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
This will likely change when multiclass classification is supported.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the cast will not change. Even when multiclass classification is supported, we will have to do an explicit cast since the binary metrics will not be available in the sealed LogisticRegressionSummary trait. Only the multiclass metrics which will not raise errors for binary data will be common to both.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we just remove this line. What say?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Downcasting is almost always an indication of a poor abstraction and IMO the stabilized API should not require any explicit typecasting by the end user, here's an explanation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other option would be just to make all metrics available in LogisticRegressionSummaryand raise errors when multiclass data is used. In that case there would have not been a need for a BinaryLogisticRegressionSummary. However this was preferred.

See:
#7538 (comment) and
#7538 (comment) and
#7538 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Synced with @jkbradley offline. Summary:

We should not require end users to perform any sort of downcasting in the stabilized API. This is OK for now since the API is still experimental.

Eventually we could provide two methods, a summary : LogisticRegressionSummary and a binarySummary : BInaryLogisticRegressionSummary which errors when called on a multiclass LRModel. This will be easy to implement because summary is returning the base LogisticRegressionSummary class so will not require any public API change.


Continuing the earlier example:

{% highlight scala %}
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
val trainingSummary = lrModel.summary

// Obtain the loss per iteration.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"loss" --> "objective" (which includes regularization)

val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory.foreach(loss => println(loss))

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]

// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = binarySummary.roc
roc.show()
roc.select("FPR").show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd remove this line and instead add a comment to the previous one if you want to say what the column names are.

println(binarySummary.areaUnderROC)

// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).
select("threshold").head().getDouble(0)
logReg.setThreshold(bestThreshold)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to re-fit the model since the threshold is only used during prediction. Instead, set the threshold in lrModel.

logReg.fit(logRegDataFrame)
{% endhighlight %}
</div>

### Optimization
<div data-lang="java" markdown="1">
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the needed imports for this new Java section.

[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
This will likely change when multiclass classification is supported.

Continuing the earlier example:

{% highlight java %}
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();

// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about Java style but from the other examples I think the colon should be one place to the left

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, see Google's Java Style Guide

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps... the Spark style guide only covers Scala and Python so I was just going by past experience on this change. We could try to get a Java style guide in if there's a community need for it

System.out.println(lossPerIteration);
}

// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary;

// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
DataFrame roc = binarySummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());

// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
DataFrame fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)).
select("threshold").head().getDouble(0);
logReg.setThreshold(bestThreshold);
logReg.fit(logRegDataFrame);
{% endhighlight %}
</div>

<div data-lang="python" markdown="1">
Logistic regression model summary is not yet supported in Python.
</div>

</div>

# Optimization

The optimization algorithm underlying the implementation is called
[Orthant-Wise Limited-memory
QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1
regularization and elastic net.

The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.