Skip to content

Commit 30734d4

Browse files
MechCodermengxr
authored andcommitted
[SPARK-9911] [DOC] [ML] Update Userguide for Evaluator
I added a small note about the different types of evaluator and the metrics used. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #8304 from MechCoder/multiclass_evaluator.
1 parent 1f90c5e commit 30734d4

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

docs/ml-guide.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,13 @@ An important task in ML is *model selection*, or using data to find the best mod
643643
Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator).
644644
`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
645645
`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
646+
647+
The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator)
648+
for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator)
649+
for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator)
650+
for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric
651+
method in each of these evaluators.
652+
646653
The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model.
647654
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
648655

@@ -708,9 +715,12 @@ val pipeline = new Pipeline()
708715
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
709716
// This will allow us to jointly choose parameters for all Pipeline stages.
710717
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
718+
// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
719+
// used is areaUnderROC.
711720
val crossval = new CrossValidator()
712721
.setEstimator(pipeline)
713722
.setEvaluator(new BinaryClassificationEvaluator)
723+
714724
// We use a ParamGridBuilder to construct a grid of parameters to search over.
715725
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
716726
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
@@ -831,9 +841,12 @@ Pipeline pipeline = new Pipeline()
831841
// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
832842
// This will allow us to jointly choose parameters for all Pipeline stages.
833843
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
844+
// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric
845+
// used is areaUnderROC.
834846
CrossValidator crossval = new CrossValidator()
835847
.setEstimator(pipeline)
836848
.setEvaluator(new BinaryClassificationEvaluator());
849+
837850
// We use a ParamGridBuilder to construct a grid of parameters to search over.
838851
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
839852
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.

0 commit comments

Comments
 (0)