@@ -38,13 +38,13 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
3838private [ml] trait DecisionTreeParams extends PredictorParams {
3939
4040 /**
41- * Maximum depth of the tree.
41+ * Maximum depth of the tree (>= 0) .
4242 * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
4343 * (default = 5)
4444 * @group param
4545 */
4646 final val maxDepth : IntParam =
47- new IntParam (this , " maxDepth" , " Maximum depth of the tree." +
47+ new IntParam (this , " maxDepth" , " Maximum depth of the tree. (>= 0) " +
4848 " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes." )
4949
5050 /**
@@ -173,6 +173,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
173173 /** @group expertGetParam */
174174 final def getCheckpointInterval : Int = getOrDefault(checkpointInterval)
175175
176+ /**
177+ * Same as [[validate() ]], but renamed to force concrete classes to explicitly implement
178+ * validation (in case concrete classes have their own parameters).
179+ */
180+ protected def validateImpl (paramMap : ParamMap ): Unit = {
181+ val map = extractParamMap(paramMap)
182+ require(map(maxDepth) >= 0 , s " ${this .getClass.getSimpleName}" +
183+ s " maxDepth must be >= 0, but was ${map(maxDepth)}" )
184+ require(map(maxBins) >= 2 , s " ${this .getClass.getSimpleName}" +
185+ s " maxBins must be >= 2, but was ${map(maxBins)}" )
186+ require(map(minInstancesPerNode) >= 1 , s " ${this .getClass.getSimpleName}" +
187+ s " minInstancesPerNode must be >= 1, but was ${map(minInstancesPerNode)}" )
188+ require(map(maxMemoryInMB) > 0 , s " ${this .getClass.getSimpleName}" +
189+ s " maxMemoryInMB must be > 0, but was ${map(maxMemoryInMB)}" )
190+ require(map(checkpointInterval) >= 1 , s " ${this .getClass.getSimpleName}" +
191+ s " checkpointInterval must be >= 1, but was ${map(checkpointInterval)}" )
192+ }
193+
176194 /** (private[ml]) Create a Strategy instance to use with the old API. */
177195 private [ml] def getOldStrategy (
178196 categoricalFeatures : Map [Int , Int ],
@@ -299,12 +317,12 @@ private[ml] object TreeRegressorParams {
299317private [ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
300318
301319 /**
302- * Fraction of the training data used for learning each decision tree.
320+ * Fraction of the training data used for learning each decision tree, in range (0, 1] .
303321 * (default = 1.0)
304322 * @group param
305323 */
306324 final val subsamplingRate : DoubleParam = new DoubleParam (this , " subsamplingRate" ,
307- " Fraction of the training data used for learning each decision tree." )
325+ " Fraction of the training data used for learning each decision tree, in range (0, 1] ." )
308326
309327 setDefault(subsamplingRate -> 1.0 )
310328
@@ -321,6 +339,14 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
321339 /** @group setParam */
322340 def setSeed (value : Long ): this .type = set(seed, value)
323341
342+ override protected def validateImpl (paramMap : ParamMap ): Unit = {
343+ super .validateImpl(paramMap)
344+ val map = extractParamMap(paramMap)
345+ val rate = map(subsamplingRate)
346+ require(0.0 < rate && rate <= 1.0 , s " ${this .getClass.getSimpleName}" +
347+ s " subsamplingRate must be in range (0, 1], but was $rate" )
348+ }
349+
324350 /**
325351 * Create a Strategy instance to use with the old API.
326352 * NOTE: The caller should set impurity and seed.
@@ -402,6 +428,18 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
402428
403429 /** @group getParam */
404430 final def getFeatureSubsetStrategy : String = getOrDefault(featureSubsetStrategy)
431+
432+ override protected def validateImpl (paramMap : ParamMap ): Unit = {
433+ super .validateImpl(paramMap)
434+ val map = extractParamMap(paramMap)
435+ require(map(numTrees) >= 1 , s " ${this .getClass.getSimpleName}" +
436+ s " numTrees must be >= 1, but was ${map(numTrees)}" )
437+ require(
438+ RandomForestParams .supportedFeatureSubsetStrategies.contains(map(featureSubsetStrategy)),
439+ s " RandomForestParams was given unrecognized featureSubsetStrategy: " +
440+ s " ${map(featureSubsetStrategy)}. Supported " +
441+ s " options: ${RandomForestParams .supportedFeatureSubsetStrategies.mkString(" , " )}" )
442+ }
405443}
406444
407445private [ml] object RandomForestParams {
0 commit comments