Closed

Changes from 1 commit (of 18)

Commits
5650e98
Changed CrossValidator and TrainValidationSplit fit methods to evalua…
BryanCutler Jan 31, 2017
36a1a68
made closure vars more explicit, moved param default to trait
BryanCutler Feb 14, 2017
b051afa
added paramvalidator for numParallelEval to ensure >=1
BryanCutler Feb 14, 2017
46fe252
added test cases for CrossValidation and TrainValidationSplit
BryanCutler Feb 15, 2017
1274ba4
added numParallelEval param usage to examples
BryanCutler Feb 15, 2017
80ac2fd
added documentation to ml-tuning
BryanCutler Feb 16, 2017
8126710
changed sliding window limit to use a semaphore instead to prevent wa…
BryanCutler Feb 16, 2017
6a9b735
added note about parallelism capped by Scala collection thread pool, …
BryanCutler Feb 16, 2017
1c2e391
reworked to use ExecutorService and Futures
BryanCutler Feb 28, 2017
9e055cd
fixed wildcard import
BryanCutler Feb 28, 2017
97ad7b4
made doc changes
BryanCutler Apr 11, 2017
5e8a086
changed ExecutorService factory to a trait to be compatible with Java
BryanCutler Apr 12, 2017
864c99c
Merge remote-tracking branch 'upstream/master' into parallel-model-ev…
BryanCutler Jun 13, 2017
ad8a870
Changed ExecutorService to be set explicitly instead of factory
BryanCutler Jun 14, 2017
911af1d
added HasParallelism trait
BryanCutler Aug 23, 2017
658aacb
Updated to use Trait HasParallelism
BryanCutler Aug 23, 2017
2c73b0b
fixed up docs
BryanCutler Aug 23, 2017
7a8221b
removed blas calculation for CrossValidator metric calc, was not necessary
BryanCutler Sep 5, 2017
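For context, the HasParallelism trait referenced in the commits above centralizes the parallelism param and the execution context that both CrossValidator and TrainValidationSplit use. A minimal sketch of the pattern, with illustrative names that are assumptions here (the real trait lives in Spark's internals and may differ):

    import java.util.concurrent.{Executor, Executors}
    import scala.concurrent.ExecutionContext

    // Sketch only: illustrative stand-in for the HasParallelism trait.
    trait HasParallelismSketch {
      // Number of models to fit concurrently.
      def parallelism: Int

      protected def getExecutionContext: ExecutionContext = {
        // Validated to be >= 1, per the "paramvalidator for numParallelEval" commit.
        require(parallelism >= 1, "parallelism must be >= 1")
        parallelism match {
          // parallelism == 1: run each fit on the caller's thread, no pool needed.
          case 1 => ExecutionContext.fromExecutor(new Executor {
            def execute(r: Runnable): Unit = r.run()
          })
          // Otherwise a fixed-size pool caps how many models train at once.
          case n => ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(n))
        }
      }
    }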
removed blas calculation for CrossValidator metric calc, was not necessary
BryanCutler committed Sep 5, 2017
commit 7a8221ba939da026a9818299cf2b897ea81766a5
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.concurrent.duration.Duration

- import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats

@@ -73,8 +72,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("1.2.0")
def this() = this(Identifiable.randomUID("cv"))

- private val f2jBLAS = new F2jBLAS
-
/** @group setParam */
@Since("1.2.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
@@ -112,7 +109,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
- val numModels = epm.length

// Create execution context based on $(parallelism)
val executionContext = getExecutionContext
Contributor commented:

In the corresponding PR for the PySpark implementation, the number of threads is limited by the number of models to be trained (https://github.com/WeichenXu123/spark/blob/be2f3d0ec50db4730c9e3f9a813a4eb96889f5b6/python/pyspark/ml/tuning.py#L261). We might do the same here, for instance by overriding the getParallelism method. What do you think?
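One way to do that, sketched below, is to cap the pool size at the number of candidate models instead of overriding the param; parallelism and numCandidates (i.e. epm.length) are assumed inputs, not part of this PR:

    import java.util.concurrent.Executors
    import scala.concurrent.ExecutionContext

    // Sketch: never create more threads than there are models to fit.
    def boundedExecutionContext(parallelism: Int, numCandidates: Int): ExecutionContext = {
      val poolSize = math.min(parallelism, numCandidates)
      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(poolSize))
    }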

@@ -129,20 +125,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logDebug(s"Train split $splitIndex with multiple sets of parameters.")

// Fit models in a Future for training in parallel
- val models = epm.map { paramMap =>
+ val modelFutures = epm.map { paramMap =>
Future[Model[_]] {
val model = est.fit(trainingDataset, paramMap)
model.asInstanceOf[Model[_]]
} (executionContext)
}

// Unpersist training data only when all models have trained
- Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
- trainingDataset.unpersist()
- } (executionContext)
+ Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+ .onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calculate a metric and allow the model to be cleaned up
- val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
+ val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
@@ -155,10 +150,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
Contributor commented:

Does it make sense to also use sequence here?

    val metrics = (ThreadUtils.awaitResult(
      Future.sequence[Double, Iterable](metricFutures), Duration.Inf)).toArray

Member (author) replied:

I thought about that, but since it's a blocking call anyway, it will still be bound by the longest running thread.

Contributor replied:

Sure, not a big deal either way
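For reference, a standalone sketch (using plain scala.concurrent.Await rather than Spark's ThreadUtils, with made-up futures) showing that the two forms block the same way and yield the same metrics:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration.Duration
    import scala.concurrent.ExecutionContext.Implicits.global

    // Hypothetical stand-ins for foldMetricFutures.
    val futures: Seq[Future[Double]] = (1 to 3).map(i => Future(i * 0.1))

    // Per-future await, as in the patch:
    val a = futures.map(Await.result(_, Duration.Inf))
    // Sequence-then-await, as suggested above:
    val b = Await.result(Future.sequence(futures), Duration.Inf)

    // Either way the total wait is bounded by the slowest future.
    assert(a == b)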

validationDataset.unpersist()
foldMetrics
- }.transpose.map(_.sum)
-
- // Calculate average metric over all splits
- f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
+ }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits

logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
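The hunk above is why the F2jBLAS import and numModels could be dropped: scaling the summed fold metrics by 1.0 / numFolds is just an average, which the transpose/map pipeline now expresses directly. A small standalone illustration with made-up fold metrics:

    // Hypothetical metrics: 2 folds x 3 candidate param maps.
    val foldMetrics: Seq[Seq[Double]] = Seq(
      Seq(0.80, 0.60, 0.70),
      Seq(0.90, 0.70, 0.50))
    val numFolds = foldMetrics.length

    // transpose regroups the metrics per candidate; each group is then averaged,
    // which is all the removed f2jBLAS.dscal call was doing.
    val metrics = foldMetrics.transpose.map(_.sum / numFolds)
    // metrics ≈ Seq(0.85, 0.65, 0.60)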
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -122,20 +122,19 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)

// Fit models in a Future for training in parallel
logDebug(s"Train split with multiple sets of parameters.")
- val models = epm.map { paramMap =>
+ val modelFutures = epm.map { paramMap =>
Future[Model[_]] {
val model = est.fit(trainingDataset, paramMap)
model.asInstanceOf[Model[_]]
} (executionContext)
}

// Unpersist training data only when all models have trained
- Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
- trainingDataset.unpersist()
- } (executionContext)
+ Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+ .onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calculate a metric and allow the model to be cleaned up
- val metricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
+ val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))