-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
Changes from 1 commit
46d3ab3
ae13440
e0f4ce6
a33c4ea
e009ee1
931fa6c
2a83fb5
f2ef609
7bacfca
654e4d5
7e997da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.io.IOException | ||
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
|
|
@@ -31,7 +32,7 @@ import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.HasParallelism | ||
| import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
|
|
@@ -67,7 +68,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { | |
| @Since("1.2.0") | ||
| class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | ||
| extends Estimator[CrossValidatorModel] | ||
| with CrossValidatorParams with HasParallelism with MLWritable with Logging { | ||
| with CrossValidatorParams with HasParallelism with HasCollectSubModels | ||
| with HasPersistSubModelsPath with MLWritable with Logging { | ||
|
|
||
| @Since("1.2.0") | ||
| def this() = this(Identifiable.randomUID("cv")) | ||
|
|
@@ -101,6 +103,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| @Since("2.3.0") | ||
| def setParallelism(value: Int): this.type = set(parallelism, value) | ||
|
|
||
| /** @group expertSetParam */ | ||
| @Since("2.3.0") | ||
| def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) | ||
|
|
||
| /** @group expertSetParam */ | ||
| @Since("2.3.0") | ||
| def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): CrossValidatorModel = { | ||
| val schema = dataset.schema | ||
|
|
@@ -117,6 +127,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| instr.logParams(numFolds, seed, parallelism) | ||
| logTuningParams(instr) | ||
|
|
||
| val collectSubModelsParam = $(collectSubModels) | ||
| val persistSubModelsPathParam = $(persistSubModelsPath) | ||
|
|
||
| var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) { | ||
|
||
| Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)) | ||
| } else null | ||
|
|
||
| // Compute metrics for each model over each split | ||
| val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) | ||
| val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => | ||
|
|
@@ -125,10 +142,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logDebug(s"Train split $splitIndex with multiple sets of parameters.") | ||
|
|
||
| // Fit models in a Future for training in parallel | ||
| val modelFutures = epm.map { paramMap => | ||
| val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => | ||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] | ||
|
|
||
| if (collectSubModelsParam) { | ||
| subModels(splitIndex)(paramIndex) = model | ||
| } | ||
| if (persistSubModelsPathParam.nonEmpty) { | ||
| val modelPath = new Path(new Path(persistSubModelsPathParam, splitIndex.toString), | ||
| paramIndex.toString).toString | ||
| model.asInstanceOf[MLWritable].save(modelPath) | ||
| } | ||
| model | ||
| } (executionContext) | ||
| } | ||
|
|
||
|
|
@@ -160,7 +186,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logInfo(s"Best cross-validation metric: $bestMetric.") | ||
| val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] | ||
| instr.logSuccess(bestModel) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this)) | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
|
|
@@ -212,14 +238,12 @@ object CrossValidator extends MLReadable[CrossValidator] { | |
|
|
||
| val (metadata, estimator, evaluator, estimatorParamMaps) = | ||
| ValidatorParams.loadImpl(path, sc, className) | ||
| val numFolds = (metadata.params \ "numFolds").extract[Int] | ||
| val seed = (metadata.params \ "seed").extract[Long] | ||
| new CrossValidator(metadata.uid) | ||
| val cv = new CrossValidator(metadata.uid) | ||
| .setEstimator(estimator) | ||
| .setEvaluator(evaluator) | ||
| .setEstimatorParamMaps(estimatorParamMaps) | ||
| .setNumFolds(numFolds) | ||
| .setSeed(seed) | ||
| DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) | ||
|
||
| cv | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -237,12 +261,17 @@ object CrossValidator extends MLReadable[CrossValidator] { | |
| class CrossValidatorModel private[ml] ( | ||
| @Since("1.4.0") override val uid: String, | ||
| @Since("1.2.0") val bestModel: Model[_], | ||
| @Since("1.5.0") val avgMetrics: Array[Double]) | ||
| @Since("1.5.0") val avgMetrics: Array[Double], | ||
| @Since("2.3.0") val subModels: Array[Array[Model[_]]]) | ||
| extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { | ||
|
|
||
| /** A Python-friendly auxiliary constructor. */ | ||
| private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { | ||
| this(uid, bestModel, avgMetrics.asScala.toArray) | ||
| this(uid, bestModel, avgMetrics.asScala.toArray, null) | ||
|
||
| } | ||
|
|
||
| private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = { | ||
| this(uid, bestModel, avgMetrics, null) | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
|
|
@@ -261,17 +290,40 @@ class CrossValidatorModel private[ml] ( | |
| val copied = new CrossValidatorModel( | ||
| uid, | ||
| bestModel.copy(extra).asInstanceOf[Model[_]], | ||
| avgMetrics.clone()) | ||
| avgMetrics.clone(), | ||
| CrossValidatorModel.copySubModels(subModels)) | ||
| copyValues(copied, extra).setParent(parent) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) | ||
|
|
||
| @Since("2.3.0") | ||
| @throws[IOException]("If the input path already exists but overwrite is not enabled.") | ||
| def save(path: String, persistSubModels: Boolean): Unit = { | ||
| write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] | ||
| .persistSubModels(persistSubModels).save(path) | ||
| } | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I add this method because the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think users can still access The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussion: Another way I think is adding an interface
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with the last suggestion of adding |
||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | ||
|
|
||
| private[CrossValidatorModel] def copySubModels(subModels: Array[Array[Model[_]]]) = { | ||
| var copiedSubModels: Array[Array[Model[_]]] = null | ||
| if (subModels != null) { | ||
| val numFolds = subModels.length | ||
| val numParamMaps = subModels(0).length | ||
| copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null)) | ||
| for (i <- 0 until numFolds) { | ||
| for (j <- 0 until numParamMaps) { | ||
| copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]] | ||
| } | ||
| } | ||
| } | ||
| copiedSubModels | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader | ||
|
|
||
|
|
@@ -283,12 +335,35 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
|
|
||
| ValidatorParams.validateParams(instance) | ||
|
|
||
| protected var shouldPersistSubModels: Boolean = false | ||
|
|
||
| /** | ||
| * Set option for persist sub models. | ||
| */ | ||
| @Since("2.3.0") | ||
| def persistSubModels(persist: Boolean): this.type = { | ||
| shouldPersistSubModels = persist | ||
| this | ||
| } | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| import org.json4s.JsonDSL._ | ||
| val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq | ||
| val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ | ||
| ("shouldPersistSubModels" -> shouldPersistSubModels) | ||
|
||
| ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) | ||
| val bestModelPath = new Path(path, "bestModel").toString | ||
| instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) | ||
| if (shouldPersistSubModels) { | ||
| require(instance.subModels != null, "Cannot get sub models to persist.") | ||
| val subModelsPath = new Path(path, "subModels") | ||
| for (splitIndex <- 0 until instance.getNumFolds) { | ||
| val splitPath = new Path(subModelsPath, splitIndex.toString) | ||
|
||
| for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { | ||
| val modelPath = new Path(splitPath, paramIndex.toString).toString | ||
| instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -303,16 +378,32 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
| val (metadata, estimator, evaluator, estimatorParamMaps) = | ||
| ValidatorParams.loadImpl(path, sc, className) | ||
| val numFolds = (metadata.params \ "numFolds").extract[Int] | ||
| val seed = (metadata.params \ "seed").extract[Long] | ||
| val bestModelPath = new Path(path, "bestModel").toString | ||
| val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) | ||
| val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray | ||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) | ||
| val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] | ||
|
|
||
| val subModels: Array[Array[Model[_]]] = if (shouldPersistSubModels) { | ||
| val subModelsPath = new Path(path, "subModels") | ||
| val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( | ||
| estimatorParamMaps.length)(null)) | ||
| for (splitIndex <- 0 until numFolds) { | ||
| val splitPath = new Path(subModelsPath, splitIndex.toString) | ||
| for (paramIndex <- 0 until estimatorParamMaps.length) { | ||
| val modelPath = new Path(splitPath, paramIndex.toString).toString | ||
| _subModels(splitIndex)(paramIndex) = | ||
| DefaultParamsReader.loadParamsInstance(modelPath, sc) | ||
| } | ||
| } | ||
| _subModels | ||
| } else null | ||
|
|
||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels) | ||
| model.set(model.estimator, estimator) | ||
| .set(model.evaluator, evaluator) | ||
| .set(model.estimatorParamMaps, estimatorParamMaps) | ||
| .set(model.numFolds, numFolds) | ||
| .set(model.seed, seed) | ||
| DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) | ||
| model | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: reword "whether to collect sub models when tuning fitting" --> "whether to collect a list of sub-models trained during tuning"