diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md
index 12b95144b5b1..85045db0528c 100644
--- a/docs/ml-linear-methods.md
+++ b/docs/ml-linear-methods.md
@@ -23,20 +23,41 @@ displayTitle: ML - Linear Methods
 \]`
 
-In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm:
+In MLlib, we implement popular linear methods such as logistic
+regression and linear least squares with $L_1$ or $L_2$ regularization.
+Refer to [the linear methods in mllib](mllib-linear-methods.html) for
+details. In `spark.ml`, we also include a Pipelines API for [Elastic
+net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid
+of $L_1$ and $L_2$ regularization proposed in [Zou et al., Regularization
+and variable selection via the elastic
+net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
+Mathematically, it is defined as a convex combination of the $L_1$ and
+the $L_2$ regularization terms:
 
 `\[
-\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1].
+\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
 \]`
 
-By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization.
-
-**Examples**
+By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
+regularization as special cases. For example, if a [linear
+regression](https://en.wikipedia.org/wiki/Linear_regression) model is
+trained with the elastic net parameter $\alpha$ set to $1$, it is
+equivalent to a
+[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model.
+On the other hand, if $\alpha$ is set to $0$, the trained model reduces
+to a [ridge
+regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model.
+(A short sketch contrasting these two settings follows the examples
+below.) We implement a Pipelines API for both linear regression and
+logistic regression with elastic net regularization.
+
+## Example: Logistic Regression
+
+The following example shows how to train a logistic regression model
+with elastic net regularization. `elasticNetParam` corresponds to
+$\alpha$ and `regParam` corresponds to $\lambda$.
- {% highlight scala %} - import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.mllib.util.MLUtils @@ -53,15 +74,11 @@ val lrModel = lr.fit(training) // Print the weights and intercept for logistic regression println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") - {% endhighlight %} -
- {% highlight java %} - import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.mllib.regression.LabeledPoint; @@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
- {% highlight python %} - from pyspark.ml.classification import LogisticRegression from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import MLUtils @@ -118,67 +133,33 @@ lrModel = lr.fit(training) print("Weights: " + str(lrModel.weights)) print("Intercept: " + str(lrModel.intercept)) {% endhighlight %} -
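+
+As a minimal sketch of the two special cases discussed above, the
+following reuses the `training` DataFrame from the examples; the
+`regParam` value `0.1` and the `lasso`/`ridge` names are illustrative
+only:
+
+{% highlight scala %}
+import org.apache.spark.ml.classification.LogisticRegression
+
+// alpha = 1: the penalty reduces to lambda * ||w||_1, i.e. Lasso.
+val lasso = new LogisticRegression()
+  .setElasticNetParam(1.0) // alpha
+  .setRegParam(0.1)        // lambda
+val lassoModel = lasso.fit(training)
+
+// alpha = 0: the penalty reduces to (lambda / 2) * ||w||_2^2, i.e. ridge regression.
+val ridge = new LogisticRegression()
+  .setElasticNetParam(0.0)
+  .setRegParam(0.1)
+val ridgeModel = ridge.fit(training)
+{% endhighlight %}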
-### Optimization
-
-The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
-(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
-
-### Model Summaries
+The `spark.ml` implementation of logistic regression also supports
+extracting a summary of the model over the training set. Note that the
+predictions and metrics, which are stored as `DataFrame`s in
+`BinaryLogisticRegressionSummary`, are annotated `@transient` and hence
+are only available on the driver.
 
-Once a linear model is fit on data, it is useful to extract statistics such as the
-loss per iteration and metrics to understand how well the model has performed on training
-and test data. The examples provided below will help in understanding how to use the summaries
-obtained by the summary method of the fitted linear models.
+
+<div class="codetabs">
-Note that the predictions and metrics which are stored as dataframes obtained from the summary -are transient and are not available on the driver. This is because these are as expensive -to store as the original data itself. +
-#### Example: Summary for LogisticRegression -
-
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) -provides an interface to access information such as `objectiveHistory` and metrics -to evaluate the performance on the training data directly with very less code to be rewritten by -the user. [`LogisticRegression`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegression) -currently supports only binary classification and hence in order to access the binary metrics -the summary must be explicitly cast to -[BinaryLogisticRegressionTrainingSummary](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary) -as done in the code below. This avoids raising errors for multiclass outputs while providing -extensiblity when multiclass classification is supported in the future. +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. -This example illustrates the use of `LogisticRegressionTrainingSummary` on some toy data. +Continuing the earlier example: {% highlight scala %} -import org.apache.spark.ml.classification.{LogisticRegression, BinaryLogisticRegressionSummary} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.Row - -// Use some random data for demonstration. -// Note that the RDD of LabeledPoints can be converted to a dataframe directly. -val data = sc.parallelize(Array( - LabeledPoint(0.0, Vectors.dense(0.2, 4.5, 1.6)), - LabeledPoint(1.0, Vectors.dense(3.1, 6.8, 3.6)), - LabeledPoint(0.0, Vectors.dense(2.4, 0.9, 1.9)), - LabeledPoint(1.0, Vectors.dense(9.1, 3.1, 3.6)), - LabeledPoint(0.0, Vectors.dense(2.5, 1.9, 9.1))) -) -val logRegDataFrame = data.toDF() - -// Run Logistic Regression on your toy data. -// Since LogisticRegression is an estimator, it returns an instance of LogisticRegressionModel -// which is a transformer. -val logReg = new LogisticRegression().setMaxIter(5).setRegParam(0.01) -val logRegModel = logReg.fit(logRegDataFrame) - -// Extract the summary directly from the returned LogisticRegressionModel instance. -val trainingSummary = logRegModel.summary +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +val trainingSummary = lrModel.summary // Obtain the loss per iteration. val objectiveHistory = trainingSummary.objectiveHistory @@ -206,60 +187,30 @@ logReg.fit(logRegDataFrame) {% endhighlight %}
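+
+In Scala, the cast described above is a plain `asInstanceOf`. A minimal
+sketch, reusing `trainingSummary` from the example:
+
+{% highlight scala %}
+import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
+
+// The summary has the generic LogisticRegressionTrainingSummary type;
+// cast it to access the binary-classification metrics.
+val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+// Receiver-operating characteristic as a DataFrame, and area under the ROC curve.
+binarySummary.roc.show()
+println(binarySummary.areaUnderROC)
+{% endhighlight %}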
-
-[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary) -provides an interface to access information such as `objectiveHistory` and metrics -to evaluate the performance on the training data directly with very less code to be rewritten by -the user. [`LogisticRegression`](api/java/org/apache/spark/ml/classification/LogisticRegression) -currently supports only binary classification and hence in order to access the binary metrics -the summary must be explicitly cast to -[BinaryLogisticRegressionTrainingSummary](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary) -as done in the code below. This avoids raising errors for multiclass outputs while providing -extensiblity when multiclass classification is supported in the future +
+[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
+provides a summary for a
+[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
+Currently, only binary classification is supported and the
+summary must be explicitly cast to
+[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
+This will likely change when multiclass classification is supported.
 
-This example illustrates the use of `LogisticRegressionTrainingSummary` on some toy data.
+Continuing the earlier example:
 
 {% highlight java %}
-import com.google.common.collect.Lists;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
-import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import static org.apache.spark.sql.functions.*;
-
-// Use some random data for demonstration.
-// Note that the RDD of LabeledPoints can be converted to a dataframe directly.
-JavaRDD data = sc.parallelize(Lists.newArrayList(
-  new LabeledPoint(0.0, Vectors.dense(0.2, 4.5, 1.6)),
-  new LabeledPoint(1.0, Vectors.dense(3.1, 6.8, 3.6)),
-  new LabeledPoint(0.0, Vectors.dense(2.4, 0.9, 1.9)),
-  new LabeledPoint(1.0, Vectors.dense(9.1, 3.1, 3.6)),
-  new LabeledPoint(0.0, Vectors.dense(2.5, 1.9, 9.1)))
-);
-DataFrame logRegDataFrame = sql.createDataFrame(data, LabeledPoint.class);
-
-// Run Logistic Regression on your toy data.
-// Since LogisticRegression is an estimator, it returns an instance of LogisticRegressionModel
-// which is a transformer.
-LogisticRegression logReg = new LogisticRegression().setMaxIter(5).setRegParam(0.01);
-LogisticRegressionModel logRegModel = logReg.fit(logRegDataFrame);
-
-// Extract the summary directly from the returned LogisticRegressionModel instance.
-LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();
+// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
+LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
 
 // Obtain the loss per iteration.
 double[] objectiveHistory = trainingSummary.objectiveHistory();
-for (double lossPerIteration: objectiveHistory) {
+for (double lossPerIteration : objectiveHistory) {
   System.out.println(lossPerIteration);
 }
 
 // Obtain the metrics useful to judge performance on test data.
+// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
+// binary classification problem.
 BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary;
 
 // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
@@ -278,4 +229,18 @@
-logReg.setThreshold(bestThreshold);
-logReg.fit(logRegDataFrame);
+lrModel.setThreshold(bestThreshold);
 {% endhighlight %}
 </div>
+ +
+Logistic regression model summary is not yet supported in Python. +
+
+
+# Optimization
+
+The optimization algorithm underlying the implementation is called
+[Orthant-Wise Limited-memory
+Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
+(OWL-QN). It is an extension of L-BFGS that can effectively handle
+$L_1$ and elastic net regularization.
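+
+For reference, the overall objective minimized here can be written by
+combining the average loss over the $n$ training examples with the
+elastic net term defined earlier. This is a sketch of the general form,
+where $L$ denotes the per-example loss (e.g. the logistic loss):
+
+`\[
+\min_{\wv} \frac{1}{n} \sum_{i=1}^n L(\wv; \mathbf{x}_i, y_i) + \alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2
+\]`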