[SPARK-18724][ML] Add TuningSummary for TrainValidationSplit and CrossValidator #16158
TrainValidationSplit.scala (modified):

```diff
@@ -25,6 +25,7 @@ import scala.language.existentials
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
@@ -123,7 +124,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
-    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+    val model = copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+    val summary = new TuningSummary(bestModel.transform(dataset), epm, metrics, bestIndex)
+    model.setSummary(Some(summary))
+    model
   }
 
   @Since("1.5.0")
@@ -226,6 +230,29 @@ class TrainValidationSplitModel private[ml] (
   @Since("2.0.0")
   override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
 
+  private var trainingSummary: Option[TuningSummary] = None
+
+  private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
+    this.trainingSummary = summary
+    this
+  }
+
+  /**
+   * Return true if there exists summary of model.
+   */
+  @Since("2.0.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: TuningSummary = trainingSummary.getOrElse {
+    throw new SparkException(
+      s"No training summary available for the ${this.getClass.getSimpleName}")
+  }
+
 }
 
 @Since("2.0.0")
@@ -275,3 +302,4 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
     }
   }
 }
+
```

Inline review comments on the new `summary` accessor:

Contributor (author): I'm thinking we should add a new trait.

Contributor (author): Addressed in #17654.
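For context on that comment, here is a minimal sketch of what such a trait could look like: it factors the `trainingSummary` plumbing out of the model class so that `hasSummary`, `setSummary`, and `summary` are not re-implemented in every tuned model. The trait name `HasTuningSummary` and its placement are hypothetical and are not part of this PR; the actual refactoring was deferred to #17654.

```scala
package org.apache.spark.ml.tuning

import org.apache.spark.SparkException

// Hypothetical trait: shared summary plumbing for models produced by tuning.
// Assumes the TuningSummary class added by this PR exists in the same package.
private[tuning] trait HasTuningSummary {

  private var trainingSummary: Option[TuningSummary] = None

  // Called by the fitting code path, mirroring setSummary in this PR.
  private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
    this.trainingSummary = summary
    this
  }

  /** True if a summary was attached when the model was fitted. */
  def hasSummary: Boolean = trainingSummary.isDefined

  /** The attached summary; throws if the model was loaded or copied without one. */
  def summary: TuningSummary = trainingSummary.getOrElse {
    throw new SparkException(
      s"No training summary available for ${this.getClass.getSimpleName}")
  }
}
```

A model would then only need to mix the trait in, for example `class TrainValidationSplitModel ... extends Model[TrainValidationSplitModel] with HasTuningSummary`, instead of duplicating the three members.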
TuningSummary.scala (new file):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tuning

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
 * :: Experimental ::
 * Summary of grid search tuning.
 *
 * @param predictions predictions of the best model on the training dataset
 * @param params estimator param maps
 * @param metrics corresponding evaluation metrics for the param maps
 * @param bestIndex index of the best-performing param map
 */
@Since("2.2.0")
@Experimental
class TuningSummary private[tuning](
    @transient val predictions: DataFrame,
    val params: Array[ParamMap],
    val metrics: Array[Double],
    val bestIndex: Int) {

  /** One row per param map, with the tuned param values and the resulting metric. */
  def trainingMetrics: DataFrame = {
    require(params.nonEmpty, "estimator param maps should not be empty")
    require(params.length == metrics.length,
      "number of estimator param maps should match number of metrics")
    val sqlContext = predictions.sqlContext
    val sc = sqlContext.sparkContext
    val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq("metrics")
    val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)
    val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
      val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)
      Row.fromSeq(values)
    }
    sqlContext.createDataFrame(rows, schema)
  }
}
```
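To make the intended workflow concrete, here is a sketch of how the new summary would be used end to end. The estimator, evaluator, param grid, and data path are ordinary Spark ML example boilerplate chosen for illustration; only `hasSummary`, `summary`, and `trainingMetrics` come from this PR.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("tuning-summary-example").getOrCreate()
val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.8)

val model = tvs.fit(training)

// New in this PR: one row per ParamMap, with its evaluation metric.
if (model.hasSummary) {
  model.summary.trainingMetrics.show(truncate = false)
}
```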
Review discussion:

Just to confirm, the tuning summary will not be saved? Since it's a small dataframe, perhaps we should consider saving it with the model? (We can do that in a later PR, however.)
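As a rough illustration of that suggestion (not something this PR implements), the small metrics table could be written next to the saved model and read back independently. The `tuningSummary` sub-path below is a made-up layout, not something `TrainValidationSplitModelWriter` knows about.

```scala
import org.apache.spark.ml.tuning.TrainValidationSplitModel
import org.apache.spark.sql.{DataFrame, SparkSession}

// Writes the fitted model and its (small) tuning table side by side.
def saveWithTuningSummary(model: TrainValidationSplitModel, path: String): Unit = {
  model.write.overwrite().save(path)
  model.summary.trainingMetrics          // the summary API added by this PR
    .coalesce(1)                         // the table has one row per ParamMap
    .write.mode("overwrite")
    .parquet(s"$path/tuningSummary")
}

// Reloading the table does not require reloading the model.
def loadTuningSummary(spark: SparkSession, path: String): DataFrame =
  spark.read.parquet(s"$path/tuningSummary")
```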
If we want to just save the tuning summary in the model, perhaps we can discard `TuningSummary` and add a `tuningSummary: DataFrame` field/function in the models. Sounds good?
Are there other obvious things that might go into the summary in the future that would make a `TuningSummary` class a better fit? Future support for, say, multiple metrics could simply extend the dataframe columns, so that is ok. But is there anything else you can think of?
There might be something like a detailed training log and training time for each model. But I'm thinking the current Summary pattern does have some room for improvement (e.g., save/load and API); it makes me feel bad when I have to duplicate code like `def hasSummary: Boolean = trainingSummary.nonEmpty`. Thus saving it to the models sounds like a good idea to me.
The latest implementation does not need to save the extra dataframe, since basically the dataframe can be generated from `$(estimatorParamMaps)` and `avgMetrics`.
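A sketch of that regeneration approach, assuming the model exposes only its existing `getEstimatorParamMaps` and `validationMetrics` (for `CrossValidatorModel`, `avgMetrics` would be used instead). The helper below is illustrative, not part of the PR.

```scala
import org.apache.spark.ml.tuning.TrainValidationSplitModel
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Rebuilds the tuning table on demand from state the model already persists,
// so nothing extra needs to be saved alongside it.
def tuningMetrics(model: TrainValidationSplitModel, spark: SparkSession): DataFrame = {
  val paramMaps = model.getEstimatorParamMaps
  val metrics = model.validationMetrics
  require(paramMaps.length == metrics.length, "expected one metric per param map")

  // One string column per tuned param, plus the numeric metric.
  val paramNames = paramMaps.head.toSeq.map(_.param.name).sorted
  val schema = StructType(
    paramNames.map(StructField(_, StringType)) :+ StructField("metric", DoubleType))
  val rows = paramMaps.zip(metrics).map { case (pm, metric) =>
    val values = pm.toSeq.sortBy(_.param.name).map(_.value.toString)
    Row.fromSeq(values :+ metric)
  }
  spark.createDataFrame(spark.sparkContext.parallelize(rows.toSeq), schema)
}
```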