-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-18724][ML] Add TuningSummary for TrainValidationSplit and CrossValidator #16158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
d1e22d5
a7cfa63
ad73c12
425a419
bd18c00
2a0af1d
1a594d0
0a698fe
bd459b1
4e3e19c
c0bc81a
bbf3f9f
b6a7c53
72aea62
91da358
297091f
670467a
36b1dd5
5844e0c
8c829b5
4aaf8e5
ceaad1c
41c4c12
4aef3aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,12 +23,13 @@ import org.json4s.jackson.JsonMethods._ | |
|
|
||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} | ||
| import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} | ||
| import org.apache.spark.ml.param.shared.HasSeed | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.ml.util.DefaultParamsReader.Metadata | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.sql.{DataFrame, Row, SparkSession} | ||
| import org.apache.spark.sql.types.{StringType, StructField, StructType} | ||
|
|
||
| /** | ||
| * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. | ||
|
|
@@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params { | |
| instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) | ||
| instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Summary of grid search tuning in the format of DataFrame. Each row contains one candidate | ||
| * paramMap and the corresponding metric of trained model. | ||
| */ | ||
| protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = { | ||
| val params = $(estimatorParamMaps) | ||
| require(params.nonEmpty, "estimator param maps should not be empty") | ||
| require(params.length == metrics.length, "estimator param maps number should match metrics") | ||
| val metricName = $(evaluator) match { | ||
| case b: BinaryClassificationEvaluator => b.getMetricName | ||
| case m: MulticlassClassificationEvaluator => m.getMetricName | ||
| case r: RegressionEvaluator => r.getMetricName | ||
| case _ => "metrics" | ||
| } | ||
| val spark = SparkSession.builder().getOrCreate() | ||
| val sc = spark.sparkContext | ||
| val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName) | ||
| val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray) | ||
| val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) => | ||
| val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString) | ||
| Row.fromSeq(values) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here the var name is a little confusing,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| spark.createDataFrame(rows, schema) | ||
| } | ||
| } | ||
|
|
||
| private[ml] object ValidatorParams { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.HasInputCol | |
| import org.apache.spark.ml.regression.LinearRegression | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
| import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} | ||
| import org.apache.spark.sql.Dataset | ||
| import org.apache.spark.sql.{Dataset, Row} | ||
| import org.apache.spark.sql.types.StructType | ||
|
|
||
| class CrossValidatorSuite | ||
|
|
@@ -66,6 +66,26 @@ class CrossValidatorSuite | |
| assert(cvModel.avgMetrics.length === lrParamMaps.length) | ||
| } | ||
|
|
||
| test("cross validation with tuning summary") { | ||
| val lr = new LogisticRegression | ||
| val lrParamMaps = new ParamGridBuilder() | ||
| .addGrid(lr.regParam, Array(0.001, 1.0, 1000.0)) | ||
| .addGrid(lr.maxIter, Array(0, 2)) | ||
| .build() | ||
| val eval = new BinaryClassificationEvaluator | ||
| val cv = new CrossValidator() | ||
| .setEstimator(lr) | ||
| .setEstimatorParamMaps(lrParamMaps) | ||
| .setEvaluator(eval) | ||
| .setNumFolds(3) | ||
| val cvModel = cv.fit(dataset) | ||
| val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) => | ||
| Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)) | ||
| } | ||
| assert(cvModel.tuningSummary.collect().toSet === expected.toSet) | ||
| assert(cvModel.tuningSummary.columns.last === eval.getMetricName) | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we add a test for the exception being thrown if no summary? |
||
| test("cross validation with linear regression") { | ||
| val dataset = sc.parallelize( | ||
| LinearDataGenerator.generateLinearInput( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There seems to be a problem here:
Suppose
`params(0)` (which is a `ParamMap`) contains ParamA and ParamB, and
`params(1)` (which is a `ParamMap`) contains ParamA and ParamC. The code here will run into problems, because you compose the row values sorted by param name but do not check whether every row exactly matches the first row.
I think a better way is to go through the whole
`ParamMap` list, collect all the params used, and sort them by name to form the DataFrame schema. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And here, using
`param_value.toString`, some array-type params will convert to unreadable strings. For example, for
`DoubleArrayParam`, `doubleArray.toString` will become "[D@XXXXX". Using
`Param.jsonEncode` is better. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks — we should support the case of custom ParamMaps.