@@ -34,7 +34,7 @@ net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
3434Mathematically, it is defined as a convex combination of the $L_1$ and
3535the $L_2$ regularization terms:
3636`\[
37- \alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
37+ \alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right), \alpha \in [0, 1], \lambda \geq 0
3838\] `
3939By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
4040regularization as special cases. For example, if a [linear
@@ -95,15 +95,15 @@ public class LogisticRegressionWithElasticNetExample {
9595
9696 SparkContext sc = new SparkContext(conf);
9797 SQLContext sql = new SQLContext(sc);
98- String path = "sample_libsvm_data.txt";
98+ String path = "data/mllib/sample_libsvm_data.txt";
9999
100100 // Load training data
101101 DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);
102102
103103 LogisticRegression lr = new LogisticRegression()
104104 .setMaxIter(10)
105105 .setRegParam(0.3)
106- .setElasticNetParam(0.8)
106+ .setElasticNetParam(0.8);
107107
108108 // Fit the model
109109 LogisticRegressionModel lrModel = lr.fit(training);
@@ -158,10 +158,12 @@ This will likely change when multiclass classification is supported.
158158Continuing the earlier example:
159159
160160{% highlight scala %}
161+ import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
162+
161163// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
162164val trainingSummary = lrModel.summary
163165
164- // Obtain the loss per iteration.
166+ // Obtain the objective per iteration.
165167val objectiveHistory = trainingSummary.objectiveHistory
166168objectiveHistory.foreach(loss => println(loss))
167169
@@ -173,17 +175,14 @@ val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary
173175// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
174176val roc = binarySummary.roc
175177roc.show()
176- roc.select("FPR").show()
177178println(binarySummary.areaUnderROC)
178179
179- // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
180- // this selected threshold.
180+ // Set the model threshold to maximize F-Measure
181181val fMeasure = binarySummary.fMeasureByThreshold
182182val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
183183val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).
184184 select("threshold").head().getDouble(0)
185- logReg.setThreshold(bestThreshold)
186- logReg.fit(logRegDataFrame)
185+ lrModel.setThreshold(bestThreshold)
187186{% endhighlight %}
188187</div>
189188
@@ -199,8 +198,12 @@ This will likely change when multiclass classification is supported.
199198Continuing the earlier example:
200199
201200{% highlight java %}
201+ import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
202+ import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
203+ import org.apache.spark.sql.functions;
204+
202205// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
203- LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();
206+ LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
204207
205208// Obtain the loss per iteration.
206209double[] objectiveHistory = trainingSummary.objectiveHistory();
@@ -222,20 +225,131 @@ System.out.println(binarySummary.areaUnderROC());
222225// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
223226// this selected threshold.
224227DataFrame fMeasure = binarySummary.fMeasureByThreshold();
225- double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
228+ double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
226229double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)).
227230 select("threshold").head().getDouble(0);
228- logReg.setThreshold(bestThreshold);
229- logReg.fit(logRegDataFrame);
231+ lrModel.setThreshold(bestThreshold);
230232{% endhighlight %}
231233</div>
232234
235+ <!-- TODO: Add python model summaries once implemented -->
233236<div data-lang="python" markdown="1">
234237Logistic regression model summary is not yet supported in Python.
235238</div>
236239
237240</div>
238241
242+ ## Example: Linear Regression
243+
244+ The interface for working with linear regression models and model
245+ summaries is similar to the logistic regression case. The following
246+ example demonstrates training an elastic net regularized linear
247+ regression model and extracting model summary statistics.
248+
249+ <div class="codetabs">
250+
251+ <div data-lang="scala" markdown="1">
252+ {% highlight scala %}
253+ import org.apache.spark.ml.regression.LinearRegression
254+ import org.apache.spark.mllib.util.MLUtils
255+
256+ // Load training data
257+ val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
258+
259+ val lr = new LinearRegression()
260+ .setMaxIter(10)
261+ .setRegParam(0.3)
262+ .setElasticNetParam(0.8)
263+
264+ // Fit the model
265+ val lrModel = lr.fit(training)
266+
267+ // Print the weights and intercept for linear regression
268+ println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
269+
270+ // Summarize the model over the training set and print out some metrics
271+ val trainingSummary = lrModel.summary
272+ println(s"numIterations: ${trainingSummary.totalIterations}")
273+ println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
274+ trainingSummary.residuals.show()
275+ println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
276+ println(s"r2: ${trainingSummary.r2}")
277+ {% endhighlight %}
278+ </div>
279+
280+ <div data-lang="java" markdown="1">
281+ {% highlight java %}
282+ import org.apache.spark.ml.regression.LinearRegression;
283+ import org.apache.spark.ml.regression.LinearRegressionModel;
284+ import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
285+ import org.apache.spark.mllib.linalg.Vectors;
286+ import org.apache.spark.mllib.regression.LabeledPoint;
287+ import org.apache.spark.mllib.util.MLUtils;
288+ import org.apache.spark.SparkConf;
289+ import org.apache.spark.SparkContext;
290+ import org.apache.spark.sql.DataFrame;
291+ import org.apache.spark.sql.SQLContext;
292+
293+ public class LinearRegressionWithElasticNetExample {
294+ public static void main(String[] args) {
295+ SparkConf conf = new SparkConf()
296+ .setAppName("Linear Regression with Elastic Net Example");
297+
298+ SparkContext sc = new SparkContext(conf);
299+ SQLContext sql = new SQLContext(sc);
300+ String path = "data/mllib/sample_libsvm_data.txt";
301+
302+ // Load training data
303+ DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);
304+
305+ LinearRegression lr = new LinearRegression()
306+ .setMaxIter(10)
307+ .setRegParam(0.3)
308+ .setElasticNetParam(0.8);
309+
310+ // Fit the model
311+ LinearRegressionModel lrModel = lr.fit(training);
312+
313+ // Print the weights and intercept for linear regression
314+ System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept());
315+
316+ // Summarize the model over the training set and print out some metrics
317+ LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
318+ System.out.println("numIterations: " + trainingSummary.totalIterations());
319+ System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
320+ trainingSummary.residuals().show();
321+ System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
322+ System.out.println("r2: " + trainingSummary.r2());
323+ }
324+ }
325+ {% endhighlight %}
326+ </div>
327+
328+ <div data-lang="python" markdown="1">
329+ <!-- TODO: Add python model summaries once implemented -->
330+ {% highlight python %}
331+ from pyspark.ml.regression import LinearRegression
332+ from pyspark.mllib.regression import LabeledPoint
333+ from pyspark.mllib.util import MLUtils
334+
335+ # Load training data
336+ training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
337+
338+ lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
339+
340+ # Fit the model
341+ lrModel = lr.fit(training)
342+
343+ # Print the weights and intercept for linear regression
344+ print("Weights: " + str(lrModel.weights))
345+ print("Intercept: " + str(lrModel.intercept))
346+
347+ # Linear regression model summary is not yet supported in Python.
348+ {% endhighlight %}
349+ </div>
350+
351+ </div>
352+
239353# Optimization
240354
241355The optimization algorithm underlying the implementation is called
0 commit comments