Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {

strategy.assertValid()

/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
Expand Down Expand Up @@ -1368,10 +1370,14 @@ object DecisionTree extends Serializable with Logging {


/*
* Ensure #bins is always greater than the categories. For multiclass classification,
* #bins should be greater than 2^(maxCategories - 1) - 1.
* Ensure numBins is always greater than the categories. For multiclass classification,
* numBins should be greater than 2^(maxCategories - 1) - 1.
* It's a limitation of the current implementation but a reasonable trade-off since features
* with large number of categories get favored over continuous features.
*
* This needs to be checked here instead of in Strategy since numBins can be determined
* by the number of training examples.
* TODO: Allow this case, where we simply will know nothing about some categories.
*/
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._

Expand Down Expand Up @@ -90,4 +90,33 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}

private[tree] def assertValid(): Unit = {
algo match {
case Classification =>
require(numClassesForClassification >= 2,
s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
s" but numClassesForClassification = $numClassesForClassification.")
require(Set(Gini, Entropy).contains(impurity),
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
s" Valid settings: Gini, Entropy")
case Regression =>
require(impurity == Variance,
s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
s" Valid settings: Variance")
case _ =>
throw new IllegalArgumentException(
s"DecisionTree Strategy given invalid algo parameter: $algo." +
s" Valid settings are: Classification, Regression.")
}
require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
s" Valid values are integers >= 0.")
require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
s" Valid values are integers >= 2.")
categoricalFeaturesInfo.foreach { case (feature, arity) =>
require(arity >= 2,
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
}

}