Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
init pr
  • Loading branch information
WeichenXu123 committed Dec 8, 2017
commit ec50dadfb0fa050945fbc0804c048de1e332d19e
Original file line number Diff line number Diff line change
Expand Up @@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
logDebug(s"Train split $splitIndex with multiple sets of parameters.")

var completeFitCount = 0
val signal = new Object
// Fit models in a Future for training in parallel
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Model[_]] {
val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
signal.synchronized {
completeFitCount += 1
signal.notify()
}

if (collectSubModelsParam) {
subModels.get(splitIndex)(paramIndex) = model
}
model
} (executionContext)
}

// Unpersist training data only when all models have trained
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
.onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
Future {
signal.synchronized {
while (completeFitCount < epm.length) {
Copy link
Contributor

@MrBago MrBago Dec 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I'm not too familiar with Futures in Scala. Is it save to create a blocking future like this, do you risk starving the thread pool? Can we jus an if statement in the synchronized block above? something like:

completeFitCount += 1
if (completeFitCount == epm.length) {
    trainingDataset.unpersist()
}

Copy link
Contributor Author

@WeichenXu123 WeichenXu123 Dec 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! your idea is really good. We don't need any "wait thread" to do this unpersist. Your solution will be simpler.
btw, about the risk of starving the thread pool, if it has the issue, then the current master code will also have this issue (because the "Future.sequence" thread use the same thread pool). But the thread was added into threadpool at the last, if it is scheduled to launch at the last, so won't casue this issue. But it seems depend on the threadpool implementation.

signal.wait()
}
}
trainingDataset.unpersist()
} (executionContext)

// Wait for metrics to be calculated before unpersisting validation dataset
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
Expand Down