|
18 | 18 | package org.apache.spark.ml.tuning |
19 | 19 |
|
20 | 20 | import java.util.{List => JList} |
21 | | -import java.util.concurrent.Semaphore |
22 | 21 |
|
23 | 22 | import scala.collection.JavaConverters._ |
| 23 | +import scala.concurrent.{ExecutionContext, Future} |
| 24 | +import scala.concurrent.duration.Duration |
24 | 25 |
|
25 | 26 | import com.github.fommil.netlib.F2jBLAS |
26 | 27 | import org.apache.hadoop.fs.Path |
27 | 28 | import org.json4s.DefaultFormats |
28 | 29 |
|
29 | 30 | import org.apache.spark.annotation.Since |
30 | 31 | import org.apache.spark.internal.Logging |
31 | | -import org.apache.spark.ml._ |
| 32 | +import org.apache.spark.ml.{Estimator, Model} |
32 | 33 | import org.apache.spark.ml.evaluation.Evaluator |
33 | | -import org.apache.spark.ml.param._ |
| 34 | +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} |
34 | 35 | import org.apache.spark.ml.util._ |
35 | 36 | import org.apache.spark.mllib.util.MLUtils |
36 | 37 | import org.apache.spark.sql.{DataFrame, Dataset} |
37 | 38 | import org.apache.spark.sql.types.StructType |
| 39 | +import org.apache.spark.util.ThreadUtils |
| 40 | + |
38 | 41 |
|
39 | 42 | /** |
40 | 43 | * Params for [[CrossValidator]] and [[CrossValidatorModel]]. |
@@ -105,48 +108,58 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) |
105 | 108 | val eval = $(evaluator) |
106 | 109 | val epm = $(estimatorParamMaps) |
107 | 110 | val numModels = epm.length |
108 | | - // Barrier to limit parallelism during model fit/evaluation |
109 | | - // NOTE: will be capped by size of thread pool used in Scala parallel collections, which is |
110 | | - // number of cores in the system by default |
111 | | - val numParBarrier = new Semaphore($(numParallelEval)) |
| 111 | + |
| 112 | + // Create execution context, run in serial if numParallelEval is 1 |
| 113 | + val executionContext = $(numParallelEval) match { |
| 114 | + case 1 => |
| 115 | + ThreadUtils.sameThread |
| 116 | + case n => |
| 117 | + ExecutionContext.fromExecutorService(executorServiceFactory(n)) |
| 118 | + } |
112 | 119 |
|
113 | 120 | val instr = Instrumentation.create(this, dataset) |
114 | 121 | instr.logParams(numFolds, seed) |
115 | 122 | logTuningParams(instr) |
116 | 123 |
|
117 | | - // Compute metrics for each model over each fold |
118 | | - logDebug("Running cross-validation with level of parallelism: " + |
119 | | - s"${numParBarrier.availablePermits()}.") |
| 124 | + // Compute metrics for each model over each split |
| 125 | + logDebug(s"Running cross-validation with level of parallelism: $numParallelEval.") |
120 | 126 | val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) |
121 | 127 | val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => |
122 | 128 | val trainingDataset = sparkSession.createDataFrame(training, schema).cache() |
123 | 129 | val validationDataset = sparkSession.createDataFrame(validation, schema).cache() |
124 | 130 | logDebug(s"Train split $splitIndex with multiple sets of parameters.") |
125 | 131 |
|
126 | | - // Fit models concurrently, limited by a barrier with '$numParallelEval' permits |
127 | | - val models = epm.par.map { paramMap => |
128 | | - numParBarrier.acquire() |
129 | | - val model = est.fit(trainingDataset, paramMap) |
130 | | - numParBarrier.release() |
131 | | - model.asInstanceOf[Model[_]] |
132 | | - }.seq |
133 | | - trainingDataset.unpersist() |
134 | | - |
135 | | - // Evaluate models concurrently, limited by a barrier with '$numParallelEval' permits |
136 | | - val foldMetrics = models.zip(epm).par.map { case (model, paramMap) => |
137 | | - numParBarrier.acquire() |
138 | | - // TODO: duplicate evaluator to take extra params from input |
139 | | - val metric = eval.evaluate(model.transform(validationDataset, paramMap)) |
140 | | - numParBarrier.release() |
141 | | - logDebug(s"Got metric $metric for model trained with $paramMap.") |
142 | | - metric |
143 | | - }.seq |
144 | | - |
| 132 | + // Fit models in a Future with thread-pool size determined by '$numParallelEval' |
| 133 | + val models = epm.map { paramMap => |
| 134 | + Future[Model[_]] { |
| 135 | + val model = est.fit(trainingDataset, paramMap) |
| 136 | + model.asInstanceOf[Model[_]] |
| 137 | + } (executionContext) |
| 138 | + } |
| 139 | + |
| 140 | + Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => |
| 141 | + trainingDataset.unpersist() |
| 142 | + } (executionContext) |
| 143 | + |
| 144 | + // Evaluate models in a Future with thread-pool size determined by '$numParallelEval' |
| 145 | + val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) => |
| 146 | + modelFuture.flatMap { model => |
| 147 | + Future { |
| 148 | + // TODO: duplicate evaluator to take extra params from input |
| 149 | + val metric = eval.evaluate(model.transform(validationDataset, paramMap)) |
| 150 | + logDebug(s"Got metric $metric for model trained with $paramMap.") |
| 151 | + metric |
| 152 | + } (executionContext) |
| 153 | + } (executionContext) |
| 154 | + } |
| 155 | + |
| 156 | + // Wait for metrics to be calculated before unpersisting validation dataset |
| 157 | + val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) |
145 | 158 | validationDataset.unpersist() |
146 | 159 | foldMetrics |
147 | | - }.reduce((mA, mB) => mA.zip(mB).map(m => m._1 + m._2)).toArray |
| 160 | + }.transpose.map(_.sum) |
148 | 161 |
|
149 | | - // Calculate average metric for all folds |
| 162 | + // Calculate average metric over all splits |
150 | 163 | f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) |
151 | 164 |
|
152 | 165 | logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") |
|
0 commit comments