-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
Changes from 10 commits
46d3ab3
ae13440
e0f4ce6
a33c4ea
e009ee1
931fa6c
2a83fb5
f2ef609
7bacfca
654e4d5
7e997da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.util.{List => JList} | ||
| import java.util.{List => JList, Locale} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.Future | ||
|
|
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.HasParallelism | ||
| import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
|
|
@@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { | |
| @Since("1.2.0") | ||
| class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | ||
| extends Estimator[CrossValidatorModel] | ||
| with CrossValidatorParams with HasParallelism with MLWritable with Logging { | ||
| with CrossValidatorParams with HasParallelism with HasCollectSubModels | ||
| with MLWritable with Logging { | ||
|
|
||
| @Since("1.2.0") | ||
| def this() = this(Identifiable.randomUID("cv")) | ||
|
|
@@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| @Since("2.3.0") | ||
| def setParallelism(value: Int): this.type = set(parallelism, value) | ||
|
|
||
| /** | ||
| * Whether to collect submodels when fitting. If set, we can get submodels from | ||
| * the returned model. | ||
| * | ||
| * Note: If this param is set, when you save the returned model, you can set an option | ||
| * "persistSubModels" to be "true" before saving, in order to save these submodels. | ||
| * You can check documents of | ||
| * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} | ||
| * for more information. | ||
| * | ||
| * @group expertSetParam | ||
| */ | ||
| @Since("2.3.0") | ||
| def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): CrossValidatorModel = { | ||
| val schema = dataset.schema | ||
|
|
@@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| instr.logParams(numFolds, seed, parallelism) | ||
| logTuningParams(instr) | ||
|
|
||
| val collectSubModelsParam = $(collectSubModels) | ||
|
|
||
| var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) { | ||
| Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) | ||
| } else None | ||
|
|
||
| // Compute metrics for each model over each split | ||
| val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) | ||
| val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => | ||
|
|
@@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logDebug(s"Train split $splitIndex with multiple sets of parameters.") | ||
|
|
||
| // Fit models in a Future for training in parallel | ||
| val modelFutures = epm.map { paramMap => | ||
| val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => | ||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] | ||
|
|
||
| if (collectSubModelsParam) { | ||
| subModels.get(splitIndex)(paramIndex) = model | ||
| } | ||
| model | ||
| } (executionContext) | ||
| } | ||
|
|
||
|
|
@@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logInfo(s"Best cross-validation metric: $bestMetric.") | ||
| val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] | ||
| instr.logSuccess(bestModel) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics) | ||
| .setSubModels(subModels).setParent(this)) | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
|
|
@@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] ( | |
| this(uid, bestModel, avgMetrics.asScala.toArray) | ||
| } | ||
|
|
||
| private var _subModels: Option[Array[Array[Model[_]]]] = None | ||
|
|
||
| private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) | ||
| : CrossValidatorModel = { | ||
| _subModels = subModels | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * @return submodels represented as a two-dimensional array. The index of the outer array is the | ||
| * fold index, and the index of the inner array corresponds to the ordering of | ||
| * estimatorParamMaps | ||
| * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, | ||
| * make sure to set collectSubModels to true before fitting. | ||
| */ | ||
| @Since("2.3.0") | ||
| def subModels: Array[Array[Model[_]]] = { | ||
| require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + | ||
| "to set collectSubModels to true before fitting.") | ||
| _subModels.get | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def hasSubModels: Boolean = _subModels.isDefined | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
|
|
@@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] ( | |
| val copied = new CrossValidatorModel( | ||
| uid, | ||
| bestModel.copy(extra).asInstanceOf[Model[_]], | ||
| avgMetrics.clone()) | ||
| avgMetrics.clone() | ||
| ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) | ||
| copyValues(copied, extra).setParent(parent) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) | ||
| override def write: CrossValidatorModel.CrossValidatorModelWriter = { | ||
| new CrossValidatorModel.CrossValidatorModelWriter(this) | ||
| } | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I add this method because the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think users can still access The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussion: Another way I think is adding an interface
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with the last suggestion of adding |
||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | ||
|
|
||
| private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) | ||
| : Option[Array[Array[Model[_]]]] = { | ||
| subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader | ||
|
|
||
| @Since("1.6.0") | ||
| override def load(path: String): CrossValidatorModel = super.load(path) | ||
|
|
||
| private[CrossValidatorModel] | ||
| class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { | ||
| /** | ||
| * Writer for CrossValidatorModel. | ||
| * @param instance CrossValidatorModel instance used to construct the writer | ||
| * | ||
| * CrossValidatorModelWriter supports an option "persistSubModels", with possible values | ||
| * "true" or "false". If you set the collectSubModels Param before fitting, then you can | ||
| * set "persistSubModels" to "true" in order to persist the subModels. By default, | ||
| * "persistSubModels" will be "true" when subModels are available and "false" otherwise. | ||
| * If subModels are not available, then setting "persistSubModels" to "true" will cause | ||
| * an exception. | ||
| */ | ||
| @Since("2.3.0") | ||
| final class CrossValidatorModelWriter private[tuning] ( | ||
| instance: CrossValidatorModel) extends MLWriter { | ||
|
|
||
| ValidatorParams.validateParams(instance) | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", | ||
| if (instance.hasSubModels) "true" else "false") | ||
|
|
||
| require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), | ||
| s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + | ||
| "values are \"true\" or \"false\"") | ||
| val persistSubModels = persistSubModelsParam.toBoolean | ||
|
|
||
| import org.json4s.JsonDSL._ | ||
| val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq | ||
| val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ | ||
| ("persistSubModels" -> persistSubModels) | ||
| ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) | ||
| val bestModelPath = new Path(path, "bestModel").toString | ||
| instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) | ||
| if (persistSubModels) { | ||
| require(instance.hasSubModels, "When persisting tuning models, you can only set " + | ||
| "persistSubModels to true if the tuning was done with collectSubModels set to true. " + | ||
| "To save the sub-models, try rerunning fitting with collectSubModels set to true.") | ||
| val subModelsPath = new Path(path, "subModels") | ||
| for (splitIndex <- 0 until instance.getNumFolds) { | ||
| val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") | ||
| for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { | ||
| val modelPath = new Path(splitPath, paramIndex.toString).toString | ||
| instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -301,11 +395,29 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
|
|
||
| val (metadata, estimator, evaluator, estimatorParamMaps) = | ||
| ValidatorParams.loadImpl(path, sc, className) | ||
| val numFolds = (metadata.params \ "numFolds").extract[Int] | ||
| val bestModelPath = new Path(path, "bestModel").toString | ||
| val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) | ||
| val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray | ||
| val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] | ||
|
||
|
|
||
| val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) { | ||
| val subModelsPath = new Path(path, "subModels") | ||
| val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( | ||
| estimatorParamMaps.length)(null)) | ||
| for (splitIndex <- 0 until numFolds) { | ||
| val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") | ||
| for (paramIndex <- 0 until estimatorParamMaps.length) { | ||
| val modelPath = new Path(splitPath, paramIndex.toString).toString | ||
| _subModels(splitIndex)(paramIndex) = | ||
| DefaultParamsReader.loadParamsInstance(modelPath, sc) | ||
| } | ||
| } | ||
| Some(_subModels) | ||
| } else None | ||
|
|
||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) | ||
| .setSubModels(subModels) | ||
| model.set(model.estimator, estimator) | ||
| .set(model.evaluator, evaluator) | ||
| .set(model.estimatorParamMaps, estimatorParamMaps) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this var seems unnecessary; it seems like we'd be better off just collecting `modelFutures` in `copyValues` (then we can avoid the mutation on L145).
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@holdenk @jkbradley I have already thought about this issue. The reason I do it this way is:
`modelFutures` and `foldMetricFutures` will be executed in a pipelined way. When `$(collectSubModels) == false`, this makes sure that the `model` generated in `modelFutures` is released in time, so that the maximum memory cost will be `numParallelism * sizeof(model)`. If we go with "collecting modelFutures" instead, it will increase the memory cost to `$(estimatorParamMaps).length * sizeof(model)`. This is a serious issue which was discussed before. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't follow #1: if we keep all the models (e.g. set
`collectSubModelsParam`), then the maximum memory cost will be `$(estimatorParamMaps).length * sizeof(model)` in either case? If we don't keep the models (e.g. set `collectSubModelsParam` to false), then you don't have to collect the future back at the end and there is no additional overhead. For #2, it's not that mutation impacts performance; it's that it makes the code less easy to reason about for no gain (unless I've misunderstood something about part 1).
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@holdenk Oh, sorry for confusing you. Yes, if we set
`collectSubModelsParam`, the memory cost will always be `$(estimatorParamMaps).length * sizeof(model)`. According to your suggestion, we would have to duplicate the code logic (but correct me if I am wrong): (1) [with] `collectSubModelsParam` [set], we cannot pipeline `modelFutures` and `foldMetricFutures`; we should execute `modelFutures` and collect the results first, and modify the `foldMetricFutures` logic, changing it into something like the following; (2) [without] `collectSubModelsParam` [set], just keep the current `modelFutures` & `foldMetricFutures` and pipeline them to execute. (Only by pipelining them can we keep the memory cost down to `numParallelism * sizeof(model)`.) So, according to your suggestion, it seems to need more code. Do you still prefer this way? Or do you have a better way to implement that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't follow up on this before. I think that @WeichenXu123's argument is valid, but please say if there are issues I'm missing, @holdenk.