-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19357][ML] Adding parallel model evaluation in ML tuning #16774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
5650e98
36a1a68
b051afa
46fe252
1274ba4
80ac2fd
8126710
6a9b735
1c2e391
9e055cd
97ad7b4
5e8a086
864c99c
ad8a870
911af1d
658aacb
2c73b0b
7a8221b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.param.shared | ||
|
|
||
| import scala.concurrent.ExecutionContext | ||
|
|
||
| import org.apache.spark.ml.param.{IntParam, Params, ParamValidators} | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
| /** | ||
| * Trait to define a level of parallelism for algorithms that are able to use | ||
| * multithreaded execution, and provide a thread-pool based execution context. | ||
| */ | ||
| private[ml] trait HasParallelism extends Params { | ||
|
|
||
| /** | ||
| * The number of threads to use when running parallel algorithms. | ||
| * Default is 1 for serial execution | ||
| * | ||
| * @group expertParam | ||
| */ | ||
| val parallelism = new IntParam(this, "parallelism", | ||
| "the number of threads to use when running parallel algorithms", ParamValidators.gtEq(1)) | ||
|
|
||
| setDefault(parallelism -> 1) | ||
|
|
||
| /** @group expertGetParam */ | ||
| def getParallelism: Int = $(parallelism) | ||
|
|
||
| /** | ||
| * Create a new execution context with a thread-pool that has a maximum number of threads | ||
| * set to the value of [[parallelism]]. If this param is set to 1, a same-thread executor | ||
| * will be used to run in serial. | ||
| */ | ||
| private[ml] def getExecutionContext: ExecutionContext = { | ||
| getParallelism match { | ||
| case 1 => | ||
| ThreadUtils.sameThread | ||
| case n => | ||
| ExecutionContext.fromExecutorService(ThreadUtils | ||
| .newDaemonCachedThreadPool(s"${this.getClass.getSimpleName}-thread-pool", n)) | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,20 +20,24 @@ package org.apache.spark.ml.tuning | |
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.Future | ||
| import scala.concurrent.duration.Duration | ||
|
|
||
| import com.github.fommil.netlib.F2jBLAS | ||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s.DefaultFormats | ||
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml._ | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.HasParallelism | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
| /** | ||
| * Params for [[CrossValidator]] and [[CrossValidatorModel]]. | ||
|
|
@@ -64,7 +68,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { | |
| @Since("1.2.0") | ||
| class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | ||
| extends Estimator[CrossValidatorModel] | ||
| with CrossValidatorParams with MLWritable with Logging { | ||
| with CrossValidatorParams with HasParallelism with MLWritable with Logging { | ||
|
|
||
| @Since("1.2.0") | ||
| def this() = this(Identifiable.randomUID("cv")) | ||
|
|
@@ -91,6 +95,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| @Since("2.0.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** | ||
| * Set the maximum level of parallelism to evaluate models in parallel. | ||
| * Default is 1 for serial evaluation | ||
| * | ||
| * @group expertSetParam | ||
| */ | ||
| @Since("2.3.0") | ||
| def setParallelism(value: Int): this.type = set(parallelism, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): CrossValidatorModel = { | ||
| val schema = dataset.schema | ||
|
|
@@ -100,31 +113,53 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| val eval = $(evaluator) | ||
| val epm = $(estimatorParamMaps) | ||
| val numModels = epm.length | ||
| val metrics = new Array[Double](epm.length) | ||
|
|
||
| // Create execution context based on $(parallelism) | ||
| val executionContext = getExecutionContext | ||
|
|
||
| val instr = Instrumentation.create(this, dataset) | ||
| instr.logParams(numFolds, seed) | ||
| instr.logParams(numFolds, seed, parallelism) | ||
| logTuningParams(instr) | ||
|
|
||
| // Compute metrics for each model over each split | ||
| val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) | ||
| splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => | ||
| val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => | ||
| val trainingDataset = sparkSession.createDataFrame(training, schema).cache() | ||
| val validationDataset = sparkSession.createDataFrame(validation, schema).cache() | ||
| // multi-model training | ||
| logDebug(s"Train split $splitIndex with multiple sets of parameters.") | ||
| val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] | ||
| trainingDataset.unpersist() | ||
| var i = 0 | ||
| while (i < numModels) { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) | ||
| logDebug(s"Got metric $metric for model trained with ${epm(i)}.") | ||
| metrics(i) += metric | ||
| i += 1 | ||
|
|
||
| // Fit models in a Future for training in parallel | ||
| val models = epm.map { paramMap => | ||
|
||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Unpersist training data only when all models have trained | ||
| Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => | ||
|
||
| trainingDataset.unpersist() | ||
| } (executionContext) | ||
|
|
||
| // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up | ||
| val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) => | ||
| modelFuture.map { model => | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(model.transform(validationDataset, paramMap)) | ||
| logDebug(s"Got metric $metric for model trained with $paramMap.") | ||
| metric | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Wait for metrics to be calculated before unpersisting validation dataset | ||
| val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to also use val metrics = (ThreadUtils.awaitResult(
Future.sequence[Double, Iterable](metricFutures), Duration.Inf)).toArray
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about that, but since it's a blocking call anyway, it will still be bound by the longest running thread.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, not a big deal either way |
||
| validationDataset.unpersist() | ||
| } | ||
| foldMetrics | ||
| }.transpose.map(_.sum) | ||
|
|
||
| // Calculate average metric over all splits | ||
| f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) | ||
|
||
|
|
||
| logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") | ||
| val (bestMetric, bestIndex) = | ||
| if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,8 @@ package org.apache.spark.ml.tuning | |
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.Future | ||
| import scala.concurrent.duration.Duration | ||
| import scala.language.existentials | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
|
|
@@ -30,9 +32,11 @@ import org.apache.spark.internal.Logging | |
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.HasParallelism | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
| /** | ||
| * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. | ||
|
|
@@ -62,7 +66,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { | |
| @Since("1.5.0") | ||
| class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) | ||
| extends Estimator[TrainValidationSplitModel] | ||
| with TrainValidationSplitParams with MLWritable with Logging { | ||
| with TrainValidationSplitParams with HasParallelism with MLWritable with Logging { | ||
|
|
||
| @Since("1.5.0") | ||
| def this() = this(Identifiable.randomUID("tvs")) | ||
|
|
@@ -87,37 +91,63 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St | |
| @Since("2.0.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** | ||
| * Set the maximum level of parallelism to evaluate models in parallel. | ||
| * Default is 1 for serial evaluation | ||
| * | ||
| * @group expertSetParam | ||
| */ | ||
| @Since("2.3.0") | ||
| def setParallelism(value: Int): this.type = set(parallelism, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { | ||
| val schema = dataset.schema | ||
| transformSchema(schema, logging = true) | ||
| val est = $(estimator) | ||
| val eval = $(evaluator) | ||
| val epm = $(estimatorParamMaps) | ||
| val numModels = epm.length | ||
| val metrics = new Array[Double](epm.length) | ||
|
|
||
| // Create execution context based on $(parallelism) | ||
| val executionContext = getExecutionContext | ||
|
|
||
| val instr = Instrumentation.create(this, dataset) | ||
| instr.logParams(trainRatio, seed) | ||
| instr.logParams(trainRatio, seed, parallelism) | ||
| logTuningParams(instr) | ||
|
|
||
| val Array(trainingDataset, validationDataset) = | ||
| dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) | ||
| trainingDataset.cache() | ||
| validationDataset.cache() | ||
|
|
||
| // multi-model training | ||
| // Fit models in a Future for training in parallel | ||
| logDebug(s"Train split with multiple sets of parameters.") | ||
| val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] | ||
| trainingDataset.unpersist() | ||
| var i = 0 | ||
| while (i < numModels) { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) | ||
| logDebug(s"Got metric $metric for model trained with ${epm(i)}.") | ||
| metrics(i) += metric | ||
| i += 1 | ||
| val models = epm.map { paramMap => | ||
|
||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Unpersist training data only when all models have trained | ||
| Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => | ||
| trainingDataset.unpersist() | ||
| } (executionContext) | ||
|
|
||
| // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up | ||
| val metricFutures = models.zip(epm).map { case (modelFuture, paramMap) => | ||
| modelFuture.map { model => | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(model.transform(validationDataset, paramMap)) | ||
| logDebug(s"Got metric $metric for model trained with $paramMap.") | ||
| metric | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Wait for all metrics to be calculated | ||
| val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) | ||
|
|
||
| // Unpersist validation set once all metrics have been produced | ||
| validationDataset.unpersist() | ||
|
|
||
| logInfo(s"Train validation split metrics: ${metrics.toSeq}") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -120,6 +120,33 @@ class CrossValidatorSuite | |
| } | ||
| } | ||
|
|
||
| test("cross validation with parallel evaluation") { | ||
| val lr = new LogisticRegression | ||
| val lrParamMaps = new ParamGridBuilder() | ||
| .addGrid(lr.regParam, Array(0.001, 1000.0)) | ||
| .addGrid(lr.maxIter, Array(0, 3)) | ||
| .build() | ||
| val eval = new BinaryClassificationEvaluator | ||
| val cv = new CrossValidator() | ||
| .setEstimator(lr) | ||
| .setEstimatorParamMaps(lrParamMaps) | ||
| .setEvaluator(eval) | ||
| .setNumFolds(2) | ||
| .setParallelism(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the seed param here is fixed by default and doesn't need to be set to ensure consistent results. I think that's why it's not set in the other tests in this suite. I'm not a fan of this behavior and I think it's better to explicitly set in tests, but then we should probably be consistent and set elsewhere too. What are your thoughts on this @MLnick ?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I agree.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah seed defaults to a hash of the class name. There has been debate over this (see SPARK-16832). Personally I also don't like that behavior, but for now that's what it is. |
||
| val cvSerialModel = cv.fit(dataset) | ||
| cv.setParallelism(2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how do we validate setParallelism is parallelizing?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a little difficult to do this in a unit test without making it flaky. I have run tests manually and verified it is working by both the expected speedup in timing and that the expected number of tasks are run concurrently. I can post some results if that would help.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BryanCutler may be worth posting the result to the JIRA for posterity. |
||
| val cvParallelModel = cv.fit(dataset) | ||
|
|
||
| val serialMetrics = cvSerialModel.avgMetrics.sorted | ||
| val parallelMetrics = cvParallelModel.avgMetrics.sorted | ||
| assert(serialMetrics === parallelMetrics) | ||
|
|
||
| val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] | ||
| val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression] | ||
| assert(parentSerial.getRegParam === parentParallel.getRegParam) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to also check
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should probably be done in the test that already runs |
||
| assert(parentSerial.getMaxIter === parentParallel.getMaxIter) | ||
| } | ||
|
|
||
| test("read/write: CrossValidator with simple estimator") { | ||
| val lr = new LogisticRegression().setMaxIter(3) | ||
| val evaluator = new BinaryClassificationEvaluator() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the corresponding PR for PySpark implementation the number of threads is limited by the number of models to be trained (https://github.com/WeichenXu123/spark/blob/be2f3d0ec50db4730c9e3f9a813a4eb96889f5b6/python/pyspark/ml/tuning.py#L261). We might do that for instance by overriding the
`getParallelism` method. What do you think about this?