[SPARK-9906] [ML] User guide for LogisticRegressionSummary #8197
@@ -809,12 +809,9 @@ loss per iteration which will provide an intuition on overfitting and metrics to
 how well the model has performed on training and test data.

 [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionTrainingSummary)
-provides an interface to access such relevant information. i.e the objectiveHistory and metrics
+provides an interface to access such relevant information. i.e the `objectiveHistory` and metrics
 to evaluate the performance on the training data directly with very less code to be rewritten by
-the user. In the future, a method would be made available in the fitted
-[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel) to obtain
-a [`LogisticRegressionSummary`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionSummary)
-of the test data as well.
+the user.
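The `objectiveHistory` mentioned above is the loss per iteration, which is what lets a user eyeball convergence and overfitting. As a minimal sketch of how one might inspect such a trace, here is plain Java with made-up loss values (no Spark required; the class and method names are hypothetical, not part of the Spark API):

```java
// Sketch: inspecting an objectiveHistory-style loss trace.
// The loss values below are invented for illustration only.
public class ObjectiveHistoryCheck {
    // Returns true if the loss never increased between iterations.
    public static boolean isMonotoneDecreasing(double[] history) {
        for (int i = 1; i < history.length; i++) {
            if (history[i] > history[i - 1]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] objectiveHistory = {0.693, 0.452, 0.311, 0.298, 0.297};
        System.out.println("Loss decreased every iteration: "
            + isMonotoneDecreasing(objectiveHistory));
    }
}
```

In the real API the trace would come from `model.summary().objectiveHistory()` rather than a hard-coded array.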
Contributor

We should document that logistic regression in ML currently [only supports two classes](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L259) in the user guide (hence the casts to
Contributor

We should also mention that the predictions are transient and that the summary is only available on the driver.
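The cast the reviewer refers to is needed because `model.summary()` is typed as the general training summary, while the ROC and threshold metrics live only on the binary variant. A guarded-downcast sketch of that pattern, using stand-in types rather than the real Spark classes (which would need a Spark runtime; the names and the ROC value here are invented):

```java
// Sketch of the guarded downcast the reviewer alludes to, with stand-in
// types instead of Spark's LogisticRegressionTrainingSummary classes.
public class SummaryCast {
    interface TrainingSummary {}

    static class BinaryTrainingSummary implements TrainingSummary {
        double areaUnderROC() { return 0.75; }  // made-up value
    }

    // Only the binary summary carries ROC metrics, so guard the cast.
    static double areaUnderROCOrNaN(TrainingSummary summary) {
        if (summary instanceof BinaryTrainingSummary) {
            return ((BinaryTrainingSummary) summary).areaUnderROC();
        }
        return Double.NaN;
    }

    public static void main(String[] args) {
        System.out.println(areaUnderROCOrNaN(new BinaryTrainingSummary()));
    }
}
```

With the real API an unconditional cast works today only because logistic regression is binary-only; the guard keeps the code safe if multiclass summaries arrive later.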
 This examples illustrates the use of `LogisticRegressionTrainingSummary` on some toy data.
Contributor

examples -> example
@@ -868,20 +865,21 @@ roc.show()
 roc.select("FPR").show()
 println(binarySummary.areaUnderROC)

-// Obtain the threshold with the highest fMeasure.
+// Print all threshold, fMeasure pairs.
 val fMeasure = binarySummary.fMeasureByThreshold
-val fScoreRDD = fMeasure.map { case Row(thresh: Double, fscore: Double) => (thresh, fscore) }
-val (highThresh, highFScore) = fScoreRDD.fold((0.0, 0.0))((threshFScore1, threshFScore2) => {
-  if (threshFScore1._2 > threshFScore2._2) threshFScore1 else threshFScore2
-})
+fMeasure.foreach { case Row(thresh: Double, fscore: Double) =>
+  println(s"Threshold: $thresh, F-Measure: $fscore") }

{% endhighlight %}
</div>
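The fold the diff removes was doing a max-by-second-element over (threshold, fMeasure) pairs. As a plain-Java sketch of that selection logic on collected pairs (the pairs below are invented; in Spark they would come from collecting `binarySummary.fMeasureByThreshold()` on the driver):

```java
// Sketch: pick the threshold with the highest F-measure from
// (threshold, fMeasure) pairs, mirroring the removed RDD fold.
// The pairs are made up for illustration.
public class BestThreshold {
    public static double bestThreshold(double[][] pairs) {
        double bestThresh = 0.0;
        double bestF = Double.NEGATIVE_INFINITY;
        for (double[] pair : pairs) {
            if (pair[1] > bestF) {   // pair = {threshold, fMeasure}
                bestF = pair[1];
                bestThresh = pair[0];
            }
        }
        return bestThresh;
    }

    public static void main(String[] args) {
        double[][] pairs = {{0.2, 0.61}, {0.5, 0.78}, {0.8, 0.66}};
        System.out.println("Best threshold: " + bestThreshold(pairs));
    }
}
```

Starting `bestF` at negative infinity (rather than the `(0.0, 0.0)` seed in the removed Scala) avoids silently returning threshold 0.0 when every F-measure is zero.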
<div data-lang="java">
{% highlight java %}
import com.google.common.collect.Lists;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
|
@@ -890,6 +888,11 @@ import org.apache.spark.mllib.regression.LabeledPoint; | |
| import org.apache.spark.mllib.linalg.Vectors; | ||
| import org.apache.spark.sql.DataFrame; | ||
| import org.apache.spark.sql.Row; | ||
| import org.apache.spark.sql.SQLContext; | ||
|
|
||
| SparkConf conf = new SparkConf().setAppName("LogisticRegressionSummary"); | ||
| JavaSparkContext jsc = new JavaSparkContext(conf); | ||
| SQLContext jsql = new SQLContext(jsc); | ||
|
|
||
| // Use some random data for demonstration. | ||
| // Note that the RDD of LabeledPoints can be converted to a dataframe directly. | ||
|
|
@@ -929,18 +932,15 @@ roc.show();
 roc.select("FPR").show();
 System.out.println(binarySummary.areaUnderROC());

-// Obtain the threshold with the highest fMeasure.
+// Print all threshold, fMeasure pairs.
 DataFrame fMeasure = binarySummary.fMeasureByThreshold();

+for (Row r: fMeasure.collect()) {
+  System.out.println("Threshold: " + r.get(0) + ", F-Measure: " + r.get(1));
+}

{% endhighlight %}
</div>
| </div> | ||
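Both tabs print the F-measure at each threshold. For reference, the metric behind `fMeasureByThreshold` is the F1 score: at a fixed threshold, F1 = 2 * precision * recall / (precision + recall). A minimal sketch computing it from confusion-matrix counts (the counts are invented, and the class name is hypothetical):

```java
// Sketch: F1 score at a fixed threshold, computed from
// confusion-matrix counts. Counts below are made up.
public class FMeasureSketch {
    public static double f1(int truePos, int falsePos, int falseNeg) {
        double precision = truePos / (double) (truePos + falsePos);
        double recall = truePos / (double) (truePos + falseNeg);
        if (precision + recall == 0.0) {
            return 0.0;  // avoid dividing by zero when both are zero
        }
        return 2.0 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        // 8 true positives, 2 false positives, 2 false negatives
        // gives precision = recall = 0.8, so F1 = 0.8.
        System.out.println("F1: " + f1(8, 2, 2));
    }
}
```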
# Dependencies

Spark ML currently depends on MLlib and has the same dependencies.
Contributor

Since this links the scala api doc, move this under the scala codetab and add another one for the java api doc.