Skip to content

Commit b987319

Browse files
committed
Partly done with adding checks, but blocking on adding checking functionality to Param
1 parent dbc9fb2 commit b987319

File tree

5 files changed

+82
-13
lines changed

5 files changed

+82
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
2121

2222
import org.apache.spark.Logging
2323
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
24-
import org.apache.spark.ml.param.{Param, ParamMap}
24+
import org.apache.spark.ml.param.{Params, Param, ParamMap}
2525
import org.apache.spark.sql.DataFrame
2626
import org.apache.spark.sql.types.StructType
2727

@@ -86,6 +86,13 @@ class Pipeline extends Estimator[PipelineModel] {
8686
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
8787
def getStages: Array[PipelineStage] = getOrDefault(stages)
8888

89+
override def validate(paramMap: ParamMap): Unit = {
90+
val map = extractParamMap(paramMap)
91+
getStages.foreach {
92+
case pStage: Params => pStage.validate(map)
93+
}
94+
}
95+
8996
/**
9097
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
9198
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -140,7 +147,7 @@ class Pipeline extends Estimator[PipelineModel] {
140147
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
141148
val map = extractParamMap(paramMap)
142149
val theStages = map(stages)
143-
require(theStages.toSet.size == theStages.size,
150+
require(theStages.toSet.size == theStages.length,
144151
"Cannot have duplicate components in a pipeline.")
145152
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
146153
}
@@ -157,6 +164,11 @@ class PipelineModel private[ml] (
157164
private[ml] val stages: Array[Transformer])
158165
extends Model[PipelineModel] with Logging {
159166

167+
override def validate(paramMap: ParamMap): Unit = {
168+
val map = fittingParamMap ++ extractParamMap(paramMap)
169+
stages.foreach(_.validate(map))
170+
}
171+
160172
/**
161173
* Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
162174
* estimator does not exist in the pipeline.
@@ -168,7 +180,7 @@ class PipelineModel private[ml] (
168180
}
169181
if (matched.isEmpty) {
170182
throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
171-
} else if (matched.size > 1) {
183+
} else if (matched.length > 1) {
172184
throw new IllegalStateException(s"Cannot have duplicate estimators in the same pipeline.")
173185
} else {
174186
matched.head.asInstanceOf[M]

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ import org.apache.spark.sql.types._
3838
@AlphaComponent
3939
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
4040

41+
override def validate(paramMap: ParamMap): Unit = { }
42+
4143
/** @group setParam */
4244
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
4345

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,23 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
3737
/**
3838
* Threshold for the number of values a categorical feature can take.
3939
* If a feature is found to have > maxCategories values, then it is declared continuous.
40+
* Must be >= 2.
4041
*
4142
* (default = 20)
4243
*/
4344
val maxCategories = new IntParam(this, "maxCategories",
44-
"Threshold for the number of values a categorical feature can take." +
45+
"Threshold for the number of values a categorical feature can take (>= 2)." +
4546
" If a feature is found to have > maxCategories values, then it is declared continuous.")
4647

48+
setDefault(maxCategories -> 20)
49+
4750
/** @group getParam */
4851
def getMaxCategories: Int = getOrDefault(maxCategories)
4952

50-
setDefault(maxCategories -> 20)
53+
override def validate(paramMap: ParamMap): Unit = {
54+
require(getOrDefault(maxCategories) >= 2,
55+
s"VectorIndexer maxCategories must be >= 2, but was ${getOrDefault(maxCategories)}")
56+
}
5157
}
5258

5359
/**

mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
3838
private[ml] trait DecisionTreeParams extends PredictorParams {
3939

4040
/**
41-
* Maximum depth of the tree.
41+
* Maximum depth of the tree (>= 0).
4242
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
4343
* (default = 5)
4444
* @group param
4545
*/
4646
final val maxDepth: IntParam =
47-
new IntParam(this, "maxDepth", "Maximum depth of the tree." +
47+
new IntParam(this, "maxDepth", "Maximum depth of the tree (>= 0)." +
4848
" E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
4949

5050
/**
@@ -173,6 +173,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
173173
/** @group expertGetParam */
174174
final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
175175

176+
/**
177+
* Same as [[validate()]], but renamed to force concrete classes to explicitly implement
178+
* validation (in case concrete classes have their own parameters).
179+
*/
180+
protected def validateImpl(paramMap: ParamMap): Unit = {
181+
val map = extractParamMap(paramMap)
182+
require(map(maxDepth) >= 0, s"${this.getClass.getSimpleName}" +
183+
s" maxDepth must be >= 0, but was ${map(maxDepth)}")
184+
require(map(maxBins) >= 2, s"${this.getClass.getSimpleName}" +
185+
s" maxBins must be >= 2, but was ${map(maxBins)}")
186+
require(map(minInstancesPerNode) >= 1, s"${this.getClass.getSimpleName}" +
187+
s" minInstancesPerNode must be >= 1, but was ${map(minInstancesPerNode)}")
188+
require(map(maxMemoryInMB) > 0, s"${this.getClass.getSimpleName}" +
189+
s" maxMemoryInMB must be > 0, but was ${map(maxMemoryInMB)}")
190+
require(map(checkpointInterval) >= 1, s"${this.getClass.getSimpleName}" +
191+
s" checkpointInterval must be >= 1, but was ${map(checkpointInterval)}")
192+
}
193+
176194
/** (private[ml]) Create a Strategy instance to use with the old API. */
177195
private[ml] def getOldStrategy(
178196
categoricalFeatures: Map[Int, Int],
@@ -299,12 +317,12 @@ private[ml] object TreeRegressorParams {
299317
private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
300318

301319
/**
302-
* Fraction of the training data used for learning each decision tree.
320+
* Fraction of the training data used for learning each decision tree, in range (0, 1].
303321
* (default = 1.0)
304322
* @group param
305323
*/
306324
final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
307-
"Fraction of the training data used for learning each decision tree.")
325+
"Fraction of the training data used for learning each decision tree, in range (0, 1].")
308326

309327
setDefault(subsamplingRate -> 1.0)
310328

@@ -321,6 +339,14 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
321339
/** @group setParam */
322340
def setSeed(value: Long): this.type = set(seed, value)
323341

342+
override protected def validateImpl(paramMap: ParamMap): Unit = {
343+
super.validateImpl(paramMap)
344+
val map = extractParamMap(paramMap)
345+
val rate = map(subsamplingRate)
346+
require(0.0 < rate && rate <= 1.0, s"${this.getClass.getSimpleName}" +
347+
s" subsamplingRate must be in range (0, 1], but was $rate")
348+
}
349+
324350
/**
325351
* Create a Strategy instance to use with the old API.
326352
* NOTE: The caller should set impurity and seed.
@@ -402,6 +428,18 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
402428

403429
/** @group getParam */
404430
final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
431+
432+
override protected def validateImpl(paramMap: ParamMap): Unit = {
433+
super.validateImpl(paramMap)
434+
val map = extractParamMap(paramMap)
435+
require(map(numTrees) >= 1, s"${this.getClass.getSimpleName}" +
436+
s" numTrees must be >= 1, but was ${map(numTrees)}")
437+
require(
438+
RandomForestParams.supportedFeatureSubsetStrategies.contains(map(featureSubsetStrategy)),
439+
s"RandomForestParams was given unrecognized featureSubsetStrategy:" +
440+
s" ${map(featureSubsetStrategy)}. Supported" +
441+
s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
442+
}
405443
}
406444

407445
private[ml] object RandomForestParams {

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,26 @@ private[ml] trait CrossValidatorParams extends Params {
6161
def getEvaluator: Evaluator = getOrDefault(evaluator)
6262

6363
/**
64-
* param for number of folds for cross validation
64+
* Param for number of folds for cross validation. Must be >= 2.
65+
* Default: 3
6566
* @group param
6667
*/
67-
val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation")
68+
val numFolds: IntParam =
69+
new IntParam(this, "numFolds", "number of folds for cross validation (>= 2)")
6870

6971
/** @group getParam */
7072
def getNumFolds: Int = getOrDefault(numFolds)
7173

7274
setDefault(numFolds -> 3)
75+
76+
override def validate(paramMap: ParamMap): Unit = {
77+
require(getOrDefault(numFolds) >= 2,
78+
s"CrossValidator numFolds must be >= 2, but was ${getOrDefault(numFolds)}")
79+
val map = extractParamMap(paramMap)
80+
getEstimatorParamMaps.foreach { eMap =>
81+
getEstimator.validate(map ++ eMap)
82+
}
83+
}
7384
}
7485

7586
/**
@@ -101,8 +112,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
101112
val est = map(estimator)
102113
val eval = map(evaluator)
103114
val epm = map(estimatorParamMaps)
104-
val numModels = epm.size
105-
val metrics = new Array[Double](epm.size)
115+
val numModels = epm.length
116+
val metrics = new Array[Double](epm.length)
106117
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
107118
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
108119
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()

0 commit comments

Comments
 (0)