[SPARK-19357][ML] Adding parallel model evaluation in ML tuning #16774
Changes from 1 commit
Commit message: "…te models in parallel"
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

@@ -51,7 +51,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
   /** @group getParam */
   def getNumFolds: Int = $(numFolds)
 
-  setDefault(numFolds -> 3)
+  setDefault(numFolds -> 3, numParallelEval -> 1)
 }
 
 /**
@@ -91,6 +91,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setNumParallelEval(value: Int): this.type = set(numParallelEval, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
@@ -100,31 +104,44 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+    val numPar = $(numParallelEval)
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(numFolds, seed)
     logTuningParams(instr)
 
     // Compute metrics for each model over each fold
+    logDebug(s"Running cross-validation with level of parallelism: $numPar.")
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
-    splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+    val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
       // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
-      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+      // Fit models concurrently, limited by using a sliding window over models
+      val models = epm.grouped(numPar).map { win =>
+        win.par.map(est.fit(trainingDataset, _))
+      }.toList.flatten.asInstanceOf[Seq[Model[_]]]
       trainingDataset.unpersist()
-      var i = 0
-      while (i < numModels) {
-        // TODO: duplicate evaluator to take extra params from input
-        val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-        logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-        metrics(i) += metric
-        i += 1
-      }
+      // Evaluate models concurrently, limited by using a sliding window over models
+      val foldMetrics = models.zip(epm).grouped(numPar).map { win =>
+        win.par.map { m =>
+          // TODO: duplicate evaluator to take extra params from input
+          val metric = eval.evaluate(m._1.transform(validationDataset, m._2))
+          logDebug(s"Got metric $metric for model trained with ${m._2}.")
+          metric
+        }
+      }.toList.flatten
       validationDataset.unpersist()
-    }
+      foldMetrics
+    }.reduce((mA, mB) => mA.zip(mB).map(m => m._1 + m._2)).toArray
 
     // Calculate average metric for all folds
     f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
 
     logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
     val (bestMetric, bestIndex) =
       if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
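Note for readers skimming the diff: the core of the change is the grouped(numPar) plus .par idiom above. The snippet below is a standalone sketch in plain Scala (no Spark; the object name SlidingWindowSketch and the string placeholders are invented for illustration, and it assumes Scala 2.11/2.12 where .par is available without the separate parallel-collections module). It shows how the pattern caps concurrency: each window of at most numPar items is mapped with a parallel collection, and the next window only starts once the current one has finished.

```scala
// Standalone sketch of the sliding-window pattern used in fit() above.
object SlidingWindowSketch {
  def main(args: Array[String]): Unit = {
    val numPar = 2
    val paramSettings = Seq("a", "b", "c", "d", "e") // stand-ins for the ParamMaps in epm

    val results = paramSettings.grouped(numPar).map { win =>
      // .par fans this window out over the default ForkJoinPool; the call blocks
      // until every element of the window has been mapped.
      win.par.map { p =>
        Thread.sleep(100) // stand-in for est.fit(trainingDataset, p)
        s"model($p)"
      }
    }.toList.flatten // toList pulls the iterator one window at a time; order is preserved

    println(results) // List(model(a), model(b), model(c), model(d), model(e))
  }
}
```

Effective concurrency is therefore min(numPar, size of the default ForkJoinPool), and a slow model in one window delays the start of the next window; that is the simplicity/throughput trade-off of this approach.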
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

@@ -50,7 +50,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
   /** @group getParam */
   def getTrainRatio: Double = $(trainRatio)
 
-  setDefault(trainRatio -> 0.75)
+  setDefault(trainRatio -> 0.75, numParallelEval -> 1)
 }
 
 /**
@@ -87,15 +87,19 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setNumParallelEval(value: Int): this.type = set(numParallelEval, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
-    val numModels = epm.length
-    val metrics = new Array[Double](epm.length)
+    val numPar = $(numParallelEval)
+    logDebug(s"Running validation with level of parallelism: $numPar.")
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(trainRatio, seed)
@@ -106,18 +110,21 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     trainingDataset.cache()
     validationDataset.cache()
 
     // multi-model training
     logDebug(s"Train split with multiple sets of parameters.")
-    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+    // Fit models concurrently, limited by using a sliding window over models
+    val models = epm.grouped(numPar).map { win =>
+      win.par.map(est.fit(trainingDataset, _))
+    }.toList.flatten.asInstanceOf[Seq[Model[_]]]
     trainingDataset.unpersist()
-    var i = 0
-    while (i < numModels) {
-      // TODO: duplicate evaluator to take extra params from input
-      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
-      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
-      metrics(i) += metric
-      i += 1
-    }
+    // Evaluate models concurrently, limited by using a sliding window over models
+    val metrics = models.zip(epm).grouped(numPar).map { win =>
+      win.par.map { m =>
+        // TODO: duplicate evaluator to take extra params from input
+        val metric = eval.evaluate(m._1.transform(validationDataset, m._2))
+        logDebug(s"Got metric $metric for model trained with ${m._2}.")
+        metric
+      }
+    }.toList.flatten.toArray
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")
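For context on the user-facing API, here is a hedged usage sketch in spark-shell style (it assumes this patch is applied, since setNumParallelEval does not exist otherwise; the estimator, the grid values, and the training DataFrame are placeholders the caller must supply):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build() // 2 x 3 = 6 candidate param maps

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator())
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.75)
  .setNumParallelEval(4) // new in this patch: fit/evaluate up to 4 models at a time

// val model = tvs.fit(training) // training: DataFrame with "label" and "features" columns
```

With the default of 1, behaviour matches the current serial code path.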
mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

@@ -24,7 +24,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
@@ -67,6 +67,17 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
 
+  /**
+   * param to control the number of models evaluated in parallel
+   *
+   * @group param
+   */
+  val numParallelEval: IntParam = new IntParam(this, "numParallelEval",
+    "max number of models to evaluate in parallel, 1 for serial evaluation")
+
+  /** @group getParam */
+  def getNumParallelEval: Int = $(numParallelEval)
+
   protected def transformSchemaImpl(schema: StructType): StructType = {
     require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
     val firstEstimatorParamMap = $(estimatorParamMaps).head
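For readers less familiar with the Params machinery this plugs into, below is a small self-contained demo (my illustration, not part of the PR; the class name ParallelEvalDemo and the ParamValidators.gtEq(1) validator are extras I added) showing how an IntParam with a default, a getter, and a setter behaves:

```scala
import org.apache.spark.ml.param.{IntParam, ParamMap, Params, ParamValidators}
import org.apache.spark.ml.util.Identifiable

// Stand-alone demo of the Param pattern used by ValidatorParams above.
class ParallelEvalDemo(override val uid: String) extends Params {

  def this() = this(Identifiable.randomUID("parallelEvalDemo"))

  // Same doc string as the PR; the gtEq(1) validator is an illustrative extra
  // that makes values below 1 fail when the param is set.
  val numParallelEval: IntParam = new IntParam(this, "numParallelEval",
    "max number of models to evaluate in parallel, 1 for serial evaluation",
    ParamValidators.gtEq(1))
  setDefault(numParallelEval -> 1)

  def getNumParallelEval: Int = $(numParallelEval)

  def setNumParallelEval(value: Int): this.type = set(numParallelEval, value)

  override def copy(extra: ParamMap): ParallelEvalDemo = defaultCopy(extra)
}

object ParallelEvalDemo {
  def main(args: Array[String]): Unit = {
    val demo = new ParallelEvalDemo()
    println(demo.getNumParallelEval) // 1, from setDefault
    demo.setNumParallelEval(4)
    println(demo.getNumParallelEval) // 4
    // demo.setNumParallelEval(0)    // would throw IllegalArgumentException
  }
}
```

The validator is optional; the declaration in the diff above accepts any Int.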
Review comment: May make more sense to put the setDefault call in the parent trait ValidatorParams.
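To make that suggestion concrete, a rough sketch of where the default could live (my illustration, not code from the PR; the trait name ValidatorParamsSketch is invented, and it assumes compilation inside the org.apache.spark.ml package because HasSeed and the tuning traits are private[ml]):

```scala
package org.apache.spark.ml.tuning

import org.apache.spark.ml.param.{IntParam, Params}
import org.apache.spark.ml.param.shared.HasSeed

// Sketch of the reviewer's suggestion: declare the param AND its default once
// in the shared parent trait, so CrossValidatorParams and
// TrainValidationSplitParams keep their existing setDefault calls
// (numFolds -> 3, trainRatio -> 0.75) untouched.
private[ml] trait ValidatorParamsSketch extends HasSeed with Params {

  /**
   * param to control the number of models evaluated in parallel
   * @group param
   */
  val numParallelEval: IntParam = new IntParam(this, "numParallelEval",
    "max number of models to evaluate in parallel, 1 for serial evaluation")

  /** @group getParam */
  def getNumParallelEval: Int = $(numParallelEval)

  // Default moved up from the two sub-traits.
  setDefault(numParallelEval -> 1)
}
```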