Skip to content

Commit 1c2e391

Browse files
committed
reworked to use ExecutorService and Futures
1 parent 6a9b735 commit 1c2e391

File tree

3 files changed

+105
-53
lines changed

3 files changed

+105
-53
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@
1818
package org.apache.spark.ml.tuning
1919

2020
import java.util.{List => JList}
21-
import java.util.concurrent.Semaphore
2221

2322
import scala.collection.JavaConverters._
23+
import scala.concurrent.{ExecutionContext, Future}
24+
import scala.concurrent.duration.Duration
2425

2526
import com.github.fommil.netlib.F2jBLAS
2627
import org.apache.hadoop.fs.Path
2728
import org.json4s.DefaultFormats
2829

2930
import org.apache.spark.annotation.Since
3031
import org.apache.spark.internal.Logging
31-
import org.apache.spark.ml._
32+
import org.apache.spark.ml.{Estimator, Model}
3233
import org.apache.spark.ml.evaluation.Evaluator
33-
import org.apache.spark.ml.param._
34+
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
3435
import org.apache.spark.ml.util._
3536
import org.apache.spark.mllib.util.MLUtils
3637
import org.apache.spark.sql.{DataFrame, Dataset}
3738
import org.apache.spark.sql.types.StructType
39+
import org.apache.spark.util.ThreadUtils
40+
3841

3942
/**
4043
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
@@ -105,48 +108,58 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
105108
val eval = $(evaluator)
106109
val epm = $(estimatorParamMaps)
107110
val numModels = epm.length
108-
// Barrier to limit parallelism during model fit/evaluation
109-
// NOTE: will be capped by size of thread pool used in Scala parallel collections, which is
110-
// number of cores in the system by default
111-
val numParBarrier = new Semaphore($(numParallelEval))
111+
112+
// Create execution context, run in serial if numParallelEval is 1
113+
val executionContext = $(numParallelEval) match {
114+
case 1 =>
115+
ThreadUtils.sameThread
116+
case n =>
117+
ExecutionContext.fromExecutorService(executorServiceFactory(n))
118+
}
112119

113120
val instr = Instrumentation.create(this, dataset)
114121
instr.logParams(numFolds, seed)
115122
logTuningParams(instr)
116123

117-
// Compute metrics for each model over each fold
118-
logDebug("Running cross-validation with level of parallelism: " +
119-
s"${numParBarrier.availablePermits()}.")
124+
// Compute metrics for each model over each split
125+
logDebug(s"Running cross-validation with level of parallelism: $numParallelEval.")
120126
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
121127
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
122128
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
123129
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
124130
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
125131

126-
// Fit models concurrently, limited by a barrier with '$numParallelEval' permits
127-
val models = epm.par.map { paramMap =>
128-
numParBarrier.acquire()
129-
val model = est.fit(trainingDataset, paramMap)
130-
numParBarrier.release()
131-
model.asInstanceOf[Model[_]]
132-
}.seq
133-
trainingDataset.unpersist()
134-
135-
// Evaluate models concurrently, limited by a barrier with '$numParallelEval' permits
136-
val foldMetrics = models.zip(epm).par.map { case (model, paramMap) =>
137-
numParBarrier.acquire()
138-
// TODO: duplicate evaluator to take extra params from input
139-
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
140-
numParBarrier.release()
141-
logDebug(s"Got metric $metric for model trained with $paramMap.")
142-
metric
143-
}.seq
144-
132+
// Fit models in a Future with thread-pool size determined by '$numParallelEval'
133+
val models = epm.map { paramMap =>
134+
Future[Model[_]] {
135+
val model = est.fit(trainingDataset, paramMap)
136+
model.asInstanceOf[Model[_]]
137+
} (executionContext)
138+
}
139+
140+
Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
141+
trainingDataset.unpersist()
142+
} (executionContext)
143+
144+
// Evaluate models in a Future with thread-pool size determined by '$numParallelEval'
145+
val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
146+
modelFuture.flatMap { model =>
147+
Future {
148+
// TODO: duplicate evaluator to take extra params from input
149+
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
150+
logDebug(s"Got metric $metric for model trained with $paramMap.")
151+
metric
152+
} (executionContext)
153+
} (executionContext)
154+
}
155+
156+
// Wait for metrics to be calculated before unpersisting validation dataset
157+
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
145158
validationDataset.unpersist()
146159
foldMetrics
147-
}.reduce((mA, mB) => mA.zip(mB).map(m => m._1 + m._2)).toArray
160+
}.transpose.map(_.sum)
148161

149-
// Calculate average metric for all folds
162+
// Calculate average metric over all splits
150163
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
151164

152165
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
package org.apache.spark.ml.tuning
1919

2020
import java.util.{List => JList}
21-
import java.util.concurrent.Semaphore
2221

2322
import scala.collection.JavaConverters._
23+
import scala.concurrent.{ExecutionContext, Future}
24+
import scala.concurrent.duration.Duration
2425
import scala.language.existentials
2526

2627
import org.apache.hadoop.fs.Path
@@ -34,6 +35,7 @@ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
3435
import org.apache.spark.ml.util._
3536
import org.apache.spark.sql.{DataFrame, Dataset}
3637
import org.apache.spark.sql.types.StructType
38+
import org.apache.spark.util.ThreadUtils
3739

3840
/**
3941
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
@@ -99,40 +101,53 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
99101
val est = $(estimator)
100102
val eval = $(evaluator)
101103
val epm = $(estimatorParamMaps)
102-
// Barrier to limit parallelism during model fit/evaluation
103-
// NOTE: will be capped by size of thread pool used in Scala parallel collections, which is
104-
// number of cores in the system by default
105-
val numParBarrier = new Semaphore($(numParallelEval))
106-
logDebug(s"Running validation with level of parallelism: ${numParBarrier.availablePermits()}.")
104+
105+
// Create execution context, run in serial if numParallelEval is 1
106+
val executionContext = $(numParallelEval) match {
107+
case 1 =>
108+
ThreadUtils.sameThread
109+
case n =>
110+
ExecutionContext.fromExecutorService(executorServiceFactory(n))
111+
}
107112

108113
val instr = Instrumentation.create(this, dataset)
109114
instr.logParams(trainRatio, seed)
110115
logTuningParams(instr)
111116

117+
logDebug(s"Running validation with level of parallelism: $numParallelEval.")
112118
val Array(trainingDataset, validationDataset) =
113119
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
114120
trainingDataset.cache()
115121
validationDataset.cache()
116122

117-
// Fit models concurrently, limited by a barrier with '$numParallelEval' permits
123+
// Fit models in a Future with thread-pool size determined by '$numParallelEval'
118124
logDebug(s"Train split with multiple sets of parameters.")
119-
val models = epm.par.map { paramMap =>
120-
numParBarrier.acquire()
121-
val model = est.fit(trainingDataset, paramMap)
122-
numParBarrier.release()
123-
model.asInstanceOf[Model[_]]
124-
}.seq
125-
trainingDataset.unpersist()
125+
val models = epm.map { paramMap =>
126+
Future[Model[_]] {
127+
val model = est.fit(trainingDataset, paramMap)
128+
model.asInstanceOf[Model[_]]
129+
} (executionContext)
130+
}
131+
132+
Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
133+
trainingDataset.unpersist()
134+
} (executionContext)
126135

127136
// Evaluate models in a Future with thread-pool size determined by '$numParallelEval'
128-
val metrics = models.zip(epm).par.map { case (model, paramMap) =>
129-
numParBarrier.acquire()
130-
// TODO: duplicate evaluator to take extra params from input
131-
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
132-
numParBarrier.release()
133-
logDebug(s"Got metric $metric for model trained with $paramMap.")
134-
metric
135-
}.seq.toArray
137+
val metricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
138+
modelFuture.flatMap { model =>
139+
Future {
140+
// TODO: duplicate evaluator to take extra params from input
141+
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
142+
logDebug(s"Got metric $metric for model trained with $paramMap.")
143+
metric
144+
} (executionContext)
145+
} (executionContext)
146+
}
147+
148+
// Wait for all metrics to be calculated
149+
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
150+
136151
validationDataset.unpersist()
137152

138153
logInfo(s"Train validation split metrics: ${metrics.toSeq}")

mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
package org.apache.spark.ml.tuning
1919

20+
import java.util.concurrent.ExecutorService
21+
2022
import org.apache.hadoop.fs.Path
2123
import org.json4s.{DefaultFormats, _}
2224
import org.json4s.jackson.JsonMethods._
2325

26+
import org.apache.spark.annotation.{Experimental, InterfaceStability}
2427
import org.apache.spark.SparkContext
2528
import org.apache.spark.ml.{Estimator, Model}
2629
import org.apache.spark.ml.evaluation.Evaluator
@@ -29,6 +32,7 @@ import org.apache.spark.ml.param.shared.HasSeed
2932
import org.apache.spark.ml.util._
3033
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
3134
import org.apache.spark.sql.types.StructType
35+
import org.apache.spark.util.ThreadUtils
3236

3337
/**
3438
* Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
@@ -80,6 +84,26 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
8084
/** @group getParam */
8185
def getNumParallelEval: Int = $(numParallelEval)
8286

87+
/**
88+
* Creates an executor service to be used for validation; defaults to a thread pool with
89+
* size of `numParallelEval`
90+
*/
91+
protected var executorServiceFactory: (Int) => ExecutorService = {
92+
(requestedMaxThreads: Int) => ThreadUtils.newDaemonCachedThreadPool(
93+
s"${this.getClass.getSimpleName}-thread-pool", requestedMaxThreads)
94+
}
95+
96+
/**
97+
* Sets a function to get an execution service to be used for validation
98+
*
99+
* @param getExecutorService function to get an ExecutorService given a requestedMaxThreads pool size
100+
*/
101+
@Experimental
102+
@InterfaceStability.Unstable
103+
def setExecutorService(getExecutorService: (Int) => ExecutorService): Unit = {
104+
executorServiceFactory = getExecutorService
105+
}
106+
83107
protected def transformSchemaImpl(schema: StructType): StructType = {
84108
require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
85109
val firstEstimatorParamMap = $(estimatorParamMaps).head

0 commit comments

Comments
 (0)