
Commit af0e124

Feynman Liang authored and mengxr committed

[SPARK-9905] [ML] [DOC] Adds LinearRegressionSummary user guide
* Adds user guide for `LinearRegressionSummary`
* Fixes unresolved issues in #8197

CC jkbradley mengxr

Author: Feynman Liang <[email protected]>

Closes #8491 from feynmanliang/SPARK-9905.
1 parent 30734d4 commit af0e124

File tree: 1 file changed (+127 -13 lines)

docs/ml-linear-methods.md

Lines changed: 127 additions & 13 deletions
@@ -34,7 +34,7 @@ net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
Mathematically, it is defined as a convex combination of the $L_1$ and
the $L_2$ regularization terms:
`\[
-\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
+\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right), \alpha \in [0, 1], \lambda \geq 0
\]`
By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
regularization as special cases. For example, if a [linear
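
Setting $\alpha$ to its endpoints recovers the two classical penalties, which is the "special cases" claim in the surrounding context; a worked check of the revised formula (illustrative, not part of the commit):

`\[
\alpha = 1: \quad \lambda \|\wv\|_1 \ \text{(the lasso penalty)}, \qquad
\alpha = 0: \quad \frac{\lambda}{2}\|\wv\|_2^2 \ \text{(the ridge penalty)}.
\]`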
@@ -95,15 +95,15 @@ public class LogisticRegressionWithElasticNetExample {

    SparkContext sc = new SparkContext(conf);
    SQLContext sql = new SQLContext(sc);
-   String path = "sample_libsvm_data.txt";
+   String path = "data/mllib/sample_libsvm_data.txt";

    // Load training data
    DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);

    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
-     .setElasticNetParam(0.8)
+     .setElasticNetParam(0.8);

    // Fit the model
    LogisticRegressionModel lrModel = lr.fit(training);
@@ -158,10 +158,12 @@ This will likely change when multiclass classification is supported.
Continuing the earlier example:

{% highlight scala %}
+import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
+
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
val trainingSummary = lrModel.summary

-// Obtain the loss per iteration.
+// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory.foreach(loss => println(loss))

@@ -173,17 +175,14 @@ val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = binarySummary.roc
roc.show()
-roc.select("FPR").show()
println(binarySummary.areaUnderROC)

-// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
-// this selected threshold.
+// Set the model threshold to maximize F-Measure
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).
  select("threshold").head().getDouble(0)
-logReg.setThreshold(bestThreshold)
-logReg.fit(logRegDataFrame)
+lrModel.setThreshold(bestThreshold)
{% endhighlight %}
</div>
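
The revised Scala snippet above sets the tuned threshold on the already-trained `lrModel` rather than refitting via the stale `logReg` references. A minimal follow-up sketch of how the tuned model would then be applied, assuming the `training` DataFrame from the earlier example is still in scope (the column selection is illustrative, not part of the commit):

{% highlight scala %}
// Score the training data with the tuned threshold; transform() appends
// rawPrediction, probability, and prediction columns to the input DataFrame.
val predictions = lrModel.transform(training)
predictions.select("label", "probability", "prediction").show()
{% endhighlight %}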

@@ -199,8 +198,12 @@ This will likely change when multiclass classification is supported.
Continuing the earlier example:

{% highlight java %}
+import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
+import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
+import org.apache.spark.sql.functions;
+
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
-LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();
+LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();

// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
@@ -222,20 +225,131 @@ System.out.println(binarySummary.areaUnderROC());
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
DataFrame fMeasure = binarySummary.fMeasureByThreshold();
-double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
+double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)).
  select("threshold").head().getDouble(0);
-logReg.setThreshold(bestThreshold);
-logReg.fit(logRegDataFrame);
+lrModel.setThreshold(bestThreshold);
{% endhighlight %}
</div>

+<!--- TODO: Add python model summaries once implemented -->
<div data-lang="python" markdown="1">
Logistic regression model summary is not yet supported in Python.
</div>

</div>

+## Example: Linear Regression
+
+The interface for working with linear regression models and model
+summaries is similar to the logistic regression case. The following
+example demonstrates training an elastic net regularized linear
+regression model and extracting model summary statistics.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data
+val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+val lr = new LinearRegression()
+  .setMaxIter(10)
+  .setRegParam(0.3)
+  .setElasticNetParam(0.8)
+
+// Fit the model
+val lrModel = lr.fit(training)
+
+// Print the weights and intercept for linear regression
+println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
+
+// Summarize the model over the training set and print out some metrics
+val trainingSummary = lrModel.summary
+println(s"numIterations: ${trainingSummary.totalIterations}")
+println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
+trainingSummary.residuals.show()
+println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
+println(s"r2: ${trainingSummary.r2}")
+{% endhighlight %}
+</div>
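
The RMSE and r2 values printed by the new summary follow the standard definitions; with labels $y_i$, predictions $\hat{y}_i$, and label mean $\bar{y}$ over $n$ training points (a worked note, not part of the commit):

`\[
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^n (y_i - \hat{y}_i)^2}, \qquad
r^2 = 1 - \frac{\sum_{i=1}^n (y_i - \hat{y}_i)^2}{\sum_{i=1}^n (y_i - \bar{y})^2}.
\]`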
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.regression.LinearRegressionModel;
+import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class LinearRegressionWithElasticNetExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf()
+      .setAppName("Linear Regression with Elastic Net Example");
+
+    SparkContext sc = new SparkContext(conf);
+    SQLContext sql = new SQLContext(sc);
+    String path = "data/mllib/sample_libsvm_data.txt";
+
+    // Load training data
+    DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);
+
+    LinearRegression lr = new LinearRegression()
+      .setMaxIter(10)
+      .setRegParam(0.3)
+      .setElasticNetParam(0.8);
+
+    // Fit the model
+    LinearRegressionModel lrModel = lr.fit(training);
+
+    // Print the weights and intercept for linear regression
+    System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept());
+
+    // Summarize the model over the training set and print out some metrics
+    LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
+    System.out.println("numIterations: " + trainingSummary.totalIterations());
+    System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
+    trainingSummary.residuals().show();
+    System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
+    System.out.println("r2: " + trainingSummary.r2());
+  }
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+<!--- TODO: Add python model summaries once implemented -->
+{% highlight python %}
+from pyspark.ml.regression import LinearRegression
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import MLUtils
+
+# Load training data
+training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
+
+# Fit the model
+lrModel = lr.fit(training)
+
+# Print the weights and intercept for linear regression
+print("Weights: " + str(lrModel.weights))
+print("Intercept: " + str(lrModel.intercept))
+
+# Linear regression model summary is not yet supported in Python.
+{% endhighlight %}
+</div>
+
+</div>
+
# Optimization

The optimization algorithm underlying the implementation is called
