-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19357][ML] Adding parallel model evaluation in ML tuning #16774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
5650e98
36a1a68
b051afa
46fe252
1274ba4
80ac2fd
8126710
6a9b735
1c2e391
9e055cd
97ad7b4
5e8a086
864c99c
ad8a870
911af1d
658aacb
2c73b0b
7a8221b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…ssary
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ | |
| import scala.concurrent.Future | ||
| import scala.concurrent.duration.Duration | ||
|
|
||
| import com.github.fommil.netlib.F2jBLAS | ||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s.DefaultFormats | ||
|
|
||
|
|
@@ -73,8 +72,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| @Since("1.2.0") | ||
| def this() = this(Identifiable.randomUID("cv")) | ||
|
|
||
| private val f2jBLAS = new F2jBLAS | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.2.0") | ||
| def setEstimator(value: Estimator[_]): this.type = set(estimator, value) | ||
|
|
@@ -112,7 +109,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| val est = $(estimator) | ||
| val eval = $(evaluator) | ||
| val epm = $(estimatorParamMaps) | ||
| val numModels = epm.length | ||
|
|
||
| // Create execution context based on $(parallelism) | ||
| val executionContext = getExecutionContext | ||
|
|
@@ -129,20 +125,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logDebug(s"Train split $splitIndex with multiple sets of parameters.") | ||
|
|
||
| // Fit models in a Future for training in parallel | ||
| val models = epm.map { paramMap => | ||
| val modelFutures = epm.map { paramMap => | ||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Unpersist training data only when all models have trained | ||
| Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => | ||
| trainingDataset.unpersist() | ||
| } (executionContext) | ||
| Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) | ||
| .onComplete { _ => trainingDataset.unpersist() } (executionContext) | ||
|
|
||
| // Evaluate models in a Future that will calculate a metric and allow model to be cleaned up | ||
| val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) => | ||
| val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => | ||
| modelFuture.map { model => | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(model.transform(validationDataset, paramMap)) | ||
|
|
@@ -155,10 +150,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does it make sense to also use the following?
val metrics = (ThreadUtils.awaitResult(
Future.sequence[Double, Iterable](metricFutures), Duration.Inf)).toArray
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought about that, but since it's a blocking call anyway, it will still be bound by the longest running thread.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, not a big deal either way |
||
| validationDataset.unpersist() | ||
| foldMetrics | ||
| }.transpose.map(_.sum) | ||
|
|
||
| // Calculate average metric over all splits | ||
| f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) | ||
| }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits | ||
|
|
||
| logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") | ||
| val (bestMetric, bestIndex) = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the corresponding PR for the PySpark implementation, the number of threads is limited by the number of models to be trained (https://github.com/WeichenXu123/spark/blob/be2f3d0ec50db4730c9e3f9a813a4eb96889f5b6/python/pyspark/ml/tuning.py#L261). We might do that, for instance, by overriding the `getParallelism` method. What do you think about this?