-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a95bc22
511ec85
bcf874a
f61e9d2
3211f02
0f676e2
b2ed1f3
b914f3b
c1565a5
a87e08f
8464a6e
e66f1b1
d036089
430d782
356daba
26d10dd
2d2aaaf
6b5651e
5f2dec2
f40381c
797f68a
931a3a7
6a38f48
db0d773
ac0b9f8
3726d20
a0ed0da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Main change: Now using << instead of math.pow.
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,13 +82,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| // depth of the decision tree | ||
| val maxDepth = strategy.maxDepth | ||
| // the max number of nodes possible given the depth of the tree | ||
| val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 | ||
| val maxNumNodes = (2 << maxDepth) - 1 | ||
| // Initialize an array to hold parent impurity calculations for each node. | ||
| val parentImpurities = new Array[Double](maxNumNodes) | ||
| // dummy value for top node (updated during first split calculation) | ||
| val nodes = new Array[Node](maxNumNodes) | ||
| val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? | ||
| nodesInTree(0) = true | ||
|
|
||
| // Calculate level for single group construction | ||
|
|
||
|
|
@@ -129,7 +127,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) | ||
| timer.stop("findBestSplits") | ||
|
|
||
| val levelNodeIndexOffset = math.pow(2, level).toInt - 1 | ||
| val levelNodeIndexOffset = (1 << level) - 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto: a comment might help readers understand the calculation using the bit shift. |
||
| for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { | ||
| val nodeIndex = levelNodeIndexOffset + index | ||
| val isLeftChild = level != 0 && nodeIndex % 2 == 1 | ||
|
|
@@ -138,8 +136,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| } else { | ||
| (nodeIndex - 2) / 2 | ||
| } | ||
| // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) | ||
| // TODO: Use above check to skip unused branch of tree | ||
| // Extract info for this node (index) at the current level. | ||
| timer.start("extractNodeInfo") | ||
| extractNodeInfo(nodeSplitStats, level, index, nodes) | ||
|
|
@@ -158,7 +154,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| timer.stop("extractInfoForLowerLevels") | ||
| logDebug("final best split = " + nodeSplitStats._1) | ||
| } | ||
| require(math.pow(2, level) == splitsStatsForLevel.length) | ||
| require((1 << level) == splitsStatsForLevel.length) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ditto: a comment will help to maintain the code.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In my next update, I actually moved all of these to intuitively named functions such as numNodesInLevel(). I'd like to stick with that, but leave it for the next update if that's OK. That will also help to consolidate the indexing code and make the 1-indexing change easier to make.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good. |
||
| // Check whether all the nodes at the current level are leaves. | ||
| val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) | ||
| logDebug("all leaf = " + allLeaf) | ||
|
|
@@ -196,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| nodes: Array[Node]): Unit = { | ||
| val split = nodeSplitStats._1 | ||
| val stats = nodeSplitStats._2 | ||
| val nodeIndex = math.pow(2, level).toInt - 1 + index | ||
| val nodeIndex = (1 << level) - 1 + index | ||
| val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) | ||
| val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) | ||
| logDebug("Node = " + node) | ||
|
|
@@ -212,24 +208,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| maxDepth: Int, | ||
| nodeSplitStats: (Split, InformationGainStats), | ||
| parentImpurities: Array[Double]): Unit = { | ||
|
|
||
| if (level >= maxDepth) { | ||
| return | ||
| } | ||
| // 0 corresponds to the left child node and 1 corresponds to the right child node. | ||
| var i = 0 | ||
| while (i <= 1) { | ||
| // Calculate the index of the node from the node level and the index at the current level. | ||
| val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i | ||
| val impurity = if (i == 0) { | ||
| nodeSplitStats._2.leftImpurity | ||
| } else { | ||
| nodeSplitStats._2.rightImpurity | ||
| } | ||
| logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) | ||
| // noting the parent impurities | ||
| parentImpurities(nodeIndex) = impurity | ||
| i += 1 | ||
| } | ||
|
|
||
| val leftNodeIndex = (2 << level) - 1 + 2 * index | ||
| val leftImpurity = nodeSplitStats._2.leftImpurity | ||
| logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) | ||
| parentImpurities(leftNodeIndex) = leftImpurity | ||
|
|
||
| val rightNodeIndex = leftNodeIndex + 1 | ||
| val rightImpurity = nodeSplitStats._2.rightImpurity | ||
| logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) | ||
| parentImpurities(rightNodeIndex) = rightImpurity | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -464,7 +456,7 @@ object DecisionTree extends Serializable with Logging { | |
| // the nodes are divided into multiple groups at each level with the number of groups | ||
| // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, | ||
| // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. | ||
| val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt | ||
| val numGroups = 1 << level - maxLevelForSingleGroup | ||
| logDebug("numGroups = " + numGroups) | ||
| var bestSplits = new Array[(Split, InformationGainStats)](0) | ||
| // Iterate over each group of nodes at a level. | ||
|
|
@@ -534,7 +526,7 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| // numNodes: Number of nodes in this (level of tree, group), | ||
| // where nodes at deeper (larger) levels may be divided into groups. | ||
| val numNodes = math.pow(2, level).toInt / numGroups | ||
| val numNodes = (1 << level) / numGroups | ||
| logDebug("numNodes = " + numNodes) | ||
|
|
||
| // Find the number of features by looking at the first sample. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should get
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I forgot to port that change from my next PR. |
||
|
|
@@ -563,24 +555,24 @@ object DecisionTree extends Serializable with Logging { | |
| * @return Leaf index if the data point reaches a leaf. | ||
| * Otherwise, last node reachable in tree matching this example. | ||
| */ | ||
| def predictNodeIndex(node: Node, features: Array[Int]): Int = { | ||
| def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { | ||
| if (node.isLeaf) { | ||
| node.id | ||
| } else { | ||
| val featureIndex = node.split.get.feature | ||
| val splitLeft = node.split.get.featureType match { | ||
| case Continuous => { | ||
| val binIndex = features(featureIndex) | ||
| val binIndex = binnedFeatures(featureIndex) | ||
| val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold | ||
| // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] | ||
| // We do not need to check lowSplit since bins are separated by splits. | ||
| featureValueUpperBound <= node.split.get.threshold | ||
| } | ||
| case Categorical => { | ||
| val featureValue = if (metadata.isUnordered(featureIndex)) { | ||
| features(featureIndex) | ||
| binnedFeatures(featureIndex) | ||
| } else { | ||
| val binIndex = features(featureIndex) | ||
| val binIndex = binnedFeatures(featureIndex) | ||
| bins(featureIndex)(binIndex).category | ||
| } | ||
| node.split.get.categories.contains(featureValue) | ||
|
|
@@ -596,9 +588,9 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
| } else { | ||
| if (splitLeft) { | ||
| predictNodeIndex(node.leftNode.get, features) | ||
| predictNodeIndex(node.leftNode.get, binnedFeatures) | ||
| } else { | ||
| predictNodeIndex(node.rightNode.get, features) | ||
| predictNodeIndex(node.rightNode.get, binnedFeatures) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -613,7 +605,7 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
|
|
||
| // Used for treePointToNodeIndex | ||
| val levelOffset = (math.pow(2, level) - 1).toInt | ||
| val levelOffset = (1 << level) - 1 | ||
|
|
||
| /** | ||
| * Find the node (indexed from 0 at the start of this level) for the given example. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe the comment should reflect that the indexing is from the start of the groupShift at a given level. |
||
|
|
@@ -678,7 +670,7 @@ object DecisionTree extends Serializable with Logging { | |
| treePoint.label.toInt | ||
| // Find all matching bins and increment their values | ||
| val featureCategories = metadata.featureArity(featureIndex) | ||
| val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 | ||
| val numCategoricalBins = (1 << featureCategories - 1) - 1 | ||
| var binIndex = 0 | ||
| while (binIndex < numCategoricalBins) { | ||
| val aggIndex = aggShift + binIndex * numClasses | ||
|
|
@@ -764,9 +756,9 @@ object DecisionTree extends Serializable with Logging { | |
| 3 * numBins * numFeatures * nodeIndex + | ||
| 3 * numBins * featureIndex + | ||
| 3 * binIndex | ||
| agg(aggIndex) = agg(aggIndex) + 1 | ||
| agg(aggIndex + 1) = agg(aggIndex + 1) + label | ||
| agg(aggIndex + 2) = agg(aggIndex + 2) + label * label | ||
| agg(aggIndex) += 1 | ||
| agg(aggIndex + 1) += label | ||
| agg(aggIndex + 2) += label * label | ||
| featureIndex += 1 | ||
| } | ||
| } | ||
|
|
@@ -1165,7 +1157,7 @@ object DecisionTree extends Serializable with Logging { | |
| // Categorical feature | ||
| val featureCategories = metadata.featureArity(featureIndex) | ||
| if (metadata.isUnordered(featureIndex)) { | ||
| math.pow(2.0, featureCategories - 1).toInt - 1 | ||
| (1 << featureCategories - 1) - 1 | ||
| } else { | ||
| featureCategories | ||
| } | ||
|
|
@@ -1257,7 +1249,7 @@ object DecisionTree extends Serializable with Logging { | |
| // Iterating over all nodes at this level | ||
| var node = 0 | ||
| while (node < numNodes) { | ||
| val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift | ||
| val nodeImpurityIndex = (1 << level) - 1 + node + groupShift | ||
| val binsForNode: Array[Double] = getBinDataForNode(node) | ||
| logDebug("nodeImpurityIndex = " + nodeImpurityIndex) | ||
| val parentNodeImpurity = parentImpurities(nodeImpurityIndex) | ||
|
|
@@ -1302,7 +1294,7 @@ object DecisionTree extends Serializable with Logging { | |
| * For multiclass classification with a low-arity feature | ||
| * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), | ||
| * the feature is split based on subsets of categories. | ||
| * There are math.pow(2, maxFeatureValue - 1) - 1 splits. | ||
| * There are (1 << maxFeatureValue - 1) - 1 splits. | ||
| * (b) "ordered features" | ||
| * For regression and binary classification, | ||
| * and for multiclass classification with a high-arity feature, | ||
|
|
@@ -1391,7 +1383,7 @@ object DecisionTree extends Serializable with Logging { | |
| if (metadata.isUnordered(featureIndex)) { | ||
| // 2^(maxFeatureValue - 1) - 1 combinations | ||
| var index = 0 | ||
| while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { | ||
| while (index < (1 << featureCategories - 1) - 1) { | ||
| val categories: List[Double] | ||
| = extractMultiClassCategories(index + 1, featureCategories) | ||
| splits(featureIndex)(index) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,7 +73,7 @@ private[tree] object DTMetadata { | |
| val unorderedFeatures = new mutable.HashSet[Int]() | ||
| if (numClasses > 2) { | ||
| strategy.categoricalFeaturesInfo.foreach { case (f, k) => | ||
| val numUnorderedBins = math.pow(2, k - 1) - 1 | ||
| val numUnorderedBins = (1 << k - 1) - 1 | ||
|
||
| if (numUnorderedBins < maxBins) { | ||
| unorderedFeatures.add(f) | ||
| } else { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a comment explaining this calculation would help. Even the previous code might be a good comment. :-)