remove threshold selection
MechCoder committed Aug 16, 2015
commit 983127077bf44c55fd85d0ca9bd206f3dd3b74fe
36 changes: 18 additions & 18 deletions docs/ml-guide.md
@@ -809,12 +809,9 @@ loss per iteration which will provide an intuition on overfitting and metrics to
how well the model has performed on training and test data.

[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionTrainingSummary)
> **Contributor:** Since this links the Scala API doc, move this under the Scala codetab and add another one for the Java API doc.

-provides an interface to access such relevant information. i.e the objectiveHistory and metrics
+provides an interface to access such relevant information. i.e the `objectiveHistory` and metrics
to evaluate the performance on the training data directly with very less code to be rewritten by
-the user. In the future, a method would be made available in the fitted
-[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel) to obtain
-a [`LogisticRegressionSummary`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionSummary)
-of the test data as well.
+the user.
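
The summary interface the paragraph describes is used roughly as follows (a minimal Scala sketch against the Spark 1.5-era ML API, not part of this diff; `training` is an assumed DataFrame with "label" and "features" columns):

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// Fit a logistic regression model; `training` is an assumed input DataFrame.
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val lrModel = lr.fit(training)

// The training summary exposes the objective history (loss per iteration)
// and metrics computed on the training data.
val trainingSummary = lrModel.summary
trainingSummary.objectiveHistory.foreach(loss => println(loss))
{% endhighlight %}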
> **Contributor:** We should document in the user guide that logistic regression in ML currently [only supports two classes](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L259) (hence the casts to `BinaryLogisticRegressionSummary`), and that the traits are there for future extensibility purposes.
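
The cast in question shows up in the code below as `binarySummary`; in Scala it amounts to the following sketch (reusing the fitted `lrModel` from the example):

{% highlight scala %}
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary

// ML logistic regression currently supports only binary classification,
// so the generic summary can be downcast to the binary-specific one
// to reach metrics such as the ROC curve and F-measure by threshold.
val binarySummary = lrModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]
{% endhighlight %}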

> **Contributor:** We should also mention that the predictions are transient and that the summary is only available on the driver.


This examples illustrates the use of `LogisticRegressionTrainingSummary` on some toy data.
> **Contributor:** examples -> example


@@ -868,20 +865,21 @@ roc.show()
roc.select("FPR").show()
println(binarySummary.areaUnderROC)

-// Obtain the threshold with the highest fMeasure.
+// Print all threshold, fMeasure pairs.
val fMeasure = binarySummary.fMeasureByThreshold
-val fScoreRDD = fMeasure.map { case Row(thresh: Double, fscore: Double) => (thresh, fscore) }
-val (highThresh, highFScore) = fScoreRDD.fold((0.0, 0.0))((threshFScore1, threshFScore2) => {
-  if (threshFScore1._2 > threshFScore2._2) threshFScore1 else threshFScore2
-})
+fMeasure.foreach { case Row(thresh: Double, fscore: Double) =>
+  println(s"Threshold: $thresh, F-Measure: $fscore") }

{% endhighlight %}
</div>
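
With the fold-based selection gone, a reader who still needs the threshold maximizing the F-measure can recover it with plain DataFrame operations. A sketch, assuming `fMeasureByThreshold` yields columns named "threshold" and "F-Measure":

{% highlight scala %}
import org.apache.spark.sql.functions.max

// Find the maximum F-measure and the threshold that attains it
// (column names assumed: "threshold" and "F-Measure").
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where(fMeasure("F-Measure") === maxFMeasure)
  .select("threshold").head().getDouble(0)
lrModel.setThreshold(bestThreshold)
{% endhighlight %}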

<div data-lang="java">
{% highlight java %}
import com.google.common.collect.Lists;

+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
@@ -890,6 +888,11 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+SparkConf conf = new SparkConf().setAppName("LogisticRegressionSummary");
+JavaSparkContext jsc = new JavaSparkContext(conf);
+SQLContext jsql = new SQLContext(jsc);

// Use some random data for demonstration.
// Note that the RDD of LabeledPoints can be converted to a dataframe directly.
@@ -929,18 +932,15 @@ roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());

-// Obtain the threshold with the highest fMeasure.
+// Print all threshold, fMeasure pairs.
DataFrame fMeasure = binarySummary.fMeasureByThreshold();

+for (Row r: fMeasure.collect()) {
+  System.out.println("Threshold: " + r.get(0) + ", F-Measure: " + r.get(1));
+}
{% endhighlight %}
</div>

</div>

# Dependencies

Spark ML currently depends on MLlib and has the same dependencies.
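
For reference, a project using both packages only needs the single MLlib artifact; an sbt sketch (standard artifact coordinates, version assumed for the 1.5 line):

{% highlight scala %}
// spark-mllib brings in both the spark.mllib and spark.ml packages.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.5.0"
{% endhighlight %}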