[SPARK-9911] [DOC] [ML] Update Userguide for Evaluator
I added a small note about the different types of evaluators and the metrics they use.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #8304 from MechCoder/multiclass_evaluator.
docs/ml-guide.md (13 additions, 0 deletions)
@@ -643,6 +643,13 @@ An important task in ML is *model selection*, or using data to find the best mod
Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator).
`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+
+The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator)
+for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator)
+for binary data, or a [`MulticlassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator)
+for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName`
+method in each of these evaluators.
+
The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model.
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
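For readers of this patch, here is a minimal Scala sketch of what the new paragraph describes: overriding an evaluator's default metric before handing it to `CrossValidator`. The metric names and the defaults noted in the comments are assumptions based on the `spark.ml` evaluation API, not text taken from this diff.

```scala
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}

// Override an evaluator's default metric before passing it to CrossValidator.
val binaryEval = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderPR")     // default is "areaUnderROC"

val multiclassEval = new MulticlassClassificationEvaluator()
  .setMetricName("weightedRecall")  // default is "f1"

// e.g. crossval.setEvaluator(multiclassEval) would then select the ParamMap with the
// best weighted recall, averaged over the k folds.
```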
@@ -708,9 +715,12 @@ val pipeline = new Pipeline()
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
+// used is areaUnderROC.
val crossval = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
+
// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
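The comments above refer to a 3 x 2 parameter grid built later in the guide's example. As a rough sketch (the `hashingTF` and `lr` stages and the concrete values are assumed from the surrounding guide code, not shown in this diff), the grid would be constructed like this:

```scala
import org.apache.spark.ml.tuning.ParamGridBuilder

// Builds the 3 x 2 = 6 ParamMaps mentioned above. hashingTF and lr are assumed to be
// the HashingTF and LogisticRegression stages defined earlier in the pipeline example.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// Wire the grid into the CrossValidator defined above.
crossval.setEstimatorParamMaps(paramGrid)
```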
@@ -831,9 +841,12 @@ Pipeline pipeline = new Pipeline()
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
+// used is areaUnderROC.