Closed
27 commits
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
Small updates based on code review.
Main change: Now using << instead of math.pow.
jkbradley committed Aug 16, 2014
commit ac0b9f84ededb9aaee477f439f711d9be8e890bd
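As an editorial aside for readers of this diff: a minimal sketch, assuming only the formulas visible below, of why the shift-based expressions are drop-in replacements for the math.pow versions — the shift stays in exact Int arithmetic, while math.pow computes through Double before truncating.

```scala
// Illustrative only: checks that the shift-based formulas used in this patch
// agree with the previous math.pow-based ones for realistic tree depths.
object ShiftEquivalenceCheck extends App {
  for (depth <- 0 to 20) {
    // Max number of nodes in a binary tree of the given depth (root at depth 0).
    val viaPow   = math.pow(2, depth + 1).toInt - 1
    val viaShift = (2 << depth) - 1          // 2 << d == 2^(d + 1)
    require(viaPow == viaShift, s"depth=$depth: $viaPow != $viaShift")

    // Index of the first node at a given level in 0-based, level-order storage.
    val offsetViaPow   = math.pow(2, depth).toInt - 1
    val offsetViaShift = (1 << depth) - 1    // 1 << d == 2^d
    require(offsetViaPow == offsetViaShift)
  }
  println("shift-based formulas match math.pow for depths 0 to 20")
}
```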
72 changes: 32 additions & 40 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -82,13 +82,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
val maxNumNodes = (2 << maxDepth) - 1
Contributor
Maybe a comment explaining this calculation will help. Even the previous code might make a good comment. :-)
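One possible wording for such a comment (an editorial suggestion only, not text from the patch):

```scala
// Maximum number of nodes in a complete binary tree of depth maxDepth
// (root at depth 0): 2^(maxDepth + 1) - 1, written as a shift because
// 2 << maxDepth == 2^(maxDepth + 1).
val maxNumNodes = (2 << maxDepth) - 1
```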

// Initialize an array to hold parent impurity calculations for each node.
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later?
nodesInTree(0) = true

// Calculate level for single group construction

@@ -129,7 +127,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")

val levelNodeIndexOffset = math.pow(2, level).toInt - 1
val levelNodeIndexOffset = (1 << level) - 1
Contributor
Ditto: a comment might help readers understand the bit-shift calculation.
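For reference, a small sketch of what this offset computes under the 0-based, level-order node layout used in this file (illustrative only):

```scala
// First node index of each level in a 0-based, level-order node array:
// level 0 -> 0, level 1 -> 1, level 2 -> 3, level 3 -> 7, ...
def levelNodeIndexOffset(level: Int): Int = (1 << level) - 1

assert((0 to 3).map(levelNodeIndexOffset) == Seq(0, 1, 3, 7))
```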

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
val nodeIndex = levelNodeIndexOffset + index
val isLeftChild = level != 0 && nodeIndex % 2 == 1
@@ -138,8 +136,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
} else {
(nodeIndex - 2) / 2
}
// if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
// TODO: Use above check to skip unused branch of tree
// Extract info for this node (index) at the current level.
timer.start("extractNodeInfo")
extractNodeInfo(nodeSplitStats, level, index, nodes)
@@ -158,7 +154,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.stop("extractInfoForLowerLevels")
logDebug("final best split = " + nodeSplitStats._1)
}
require(math.pow(2, level) == splitsStatsForLevel.length)
require((1 << level) == splitsStatsForLevel.length)
Contributor
Ditto: a comment will help maintain the code.

Member Author
In my next update, I actually moved all of these into intuitively named functions such as numNodesInLevel(). I'd like to stick with that, but leave it for the next update if that's OK. That will also help consolidate the indexing code and make the 1-indexing change easier to make.

Contributor
Sounds good.
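For readers following the thread, a hypothetical sketch of what such helpers could look like; the names and signatures here are assumptions based on the comment above, not the actual follow-up change.

```scala
// Hypothetical helpers consolidating the 0-based, level-order index arithmetic.
object NodeIndexing {
  /** Number of nodes in a full level of a binary tree (level 0 is the root). */
  def numNodesInLevel(level: Int): Int = 1 << level

  /** Index of the first node in the given level. */
  def startIndexInLevel(level: Int): Int = (1 << level) - 1

  /** Global node index of the `indexInLevel`-th node of `level`. */
  def nodeIndex(level: Int, indexInLevel: Int): Int =
    startIndexInLevel(level) + indexInLevel
}
```

With helpers like these, the require above could read require(numNodesInLevel(level) == splitsStatsForLevel.length).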

// Check whether all the nodes at the current level are leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
@@ -196,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
nodes: Array[Node]): Unit = {
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val nodeIndex = math.pow(2, level).toInt - 1 + index
val nodeIndex = (1 << level) - 1 + index
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
@@ -212,24 +208,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
maxDepth: Int,
nodeSplitStats: (Split, InformationGainStats),
parentImpurities: Array[Double]): Unit = {

if (level >= maxDepth) {
return
}
// 0 corresponds to the left child node and 1 corresponds to the right child node.
var i = 0
while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
val impurity = if (i == 0) {
nodeSplitStats._2.leftImpurity
} else {
nodeSplitStats._2.rightImpurity
}
logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
// noting the parent impurities
parentImpurities(nodeIndex) = impurity
i += 1
}

val leftNodeIndex = (2 << level) - 1 + 2 * index
val leftImpurity = nodeSplitStats._2.leftImpurity
logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
parentImpurities(leftNodeIndex) = leftImpurity

val rightNodeIndex = leftNodeIndex + 1
val rightImpurity = nodeSplitStats._2.rightImpurity
logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
parentImpurities(rightNodeIndex) = rightImpurity
}
}

@@ -464,7 +456,7 @@ object DecisionTree extends Serializable with Logging {
// the nodes are divided into multiple groups at each level with the number of groups
// increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
// Iterate over each group of nodes at a level.
@@ -534,7 +526,7 @@ object DecisionTree extends Serializable with Logging {

// numNodes: Number of nodes in this (level of tree, group),
// where nodes at deeper (larger) levels may be divided into groups.
val numNodes = math.pow(2, level).toInt / numGroups
val numNodes = (1 << level) / numGroups
logDebug("numNodes = " + numNodes)

// Find the number of features by looking at the first sample.
Contributor
Should get numFeatures from metadata. Calling first() triggers at least one job.

Member Author
Thanks! I forgot to port that change from my next PR.
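A sketch of the suggested change; the numFeatures field on the metadata object is an assumption here, and the commented-out "before" line only approximates the shape of the old code.

```scala
// Before (roughly): inspecting the first sample launches at least one Spark job.
// val numFeatures = retainedInput.first().binnedFeatures.size   // hypothetical old shape

// Suggested: read the feature count from the precomputed metadata (a local value),
// assuming the metadata class exposes a numFeatures field.
val numFeatures = metadata.numFeatures
```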

@@ -563,24 +555,24 @@ object DecisionTree extends Serializable with Logging {
* @return Leaf index if the data point reaches a leaf.
* Otherwise, last node reachable in tree matching this example.
*/
def predictNodeIndex(node: Node, features: Array[Int]): Int = {
def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
if (node.isLeaf) {
node.id
} else {
val featureIndex = node.split.get.feature
val splitLeft = node.split.get.featureType match {
case Continuous => {
val binIndex = features(featureIndex)
val binIndex = binnedFeatures(featureIndex)
val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
// bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
// We do not need to check lowSplit since bins are separated by splits.
featureValueUpperBound <= node.split.get.threshold
}
case Categorical => {
val featureValue = if (metadata.isUnordered(featureIndex)) {
features(featureIndex)
binnedFeatures(featureIndex)
} else {
val binIndex = features(featureIndex)
val binIndex = binnedFeatures(featureIndex)
bins(featureIndex)(binIndex).category
}
node.split.get.categories.contains(featureValue)
@@ -596,9 +588,9 @@ object DecisionTree extends Serializable with Logging {
}
} else {
if (splitLeft) {
predictNodeIndex(node.leftNode.get, features)
predictNodeIndex(node.leftNode.get, binnedFeatures)
} else {
predictNodeIndex(node.rightNode.get, features)
predictNodeIndex(node.rightNode.get, binnedFeatures)
}
}
}
@@ -613,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
}

// Used for treePointToNodeIndex
val levelOffset = (math.pow(2, level) - 1).toInt
val levelOffset = (1 << level) - 1

/**
* Find the node (indexed from 0 at the start of this level) for the given example.
Contributor
Maybe the comment should reflect that the indexing is from the start of the groupShift at a given level.
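A rough sketch of the index decomposition the comment could describe, taken from the nodeImpurityIndex expression further down in this diff; the helper name is only illustrative.

```scala
// A node handled in this (level, group) pass has global, 0-based index:
//   levelOffset + groupShift + indexWithinGroup
// where levelOffset = (1 << level) - 1 is the first index of the level and
// groupShift selects which block of the level's nodes this pass covers.
def globalNodeIndex(level: Int, groupShift: Int, indexWithinGroup: Int): Int =
  ((1 << level) - 1) + groupShift + indexWithinGroup
```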

@@ -678,7 +670,7 @@ object DecisionTree extends Serializable with Logging {
treePoint.label.toInt
// Find all matching bins and increment their values
val featureCategories = metadata.featureArity(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
val numCategoricalBins = (1 << featureCategories - 1) - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
val aggIndex = aggShift + binIndex * numClasses
@@ -764,9 +756,9 @@ object DecisionTree extends Serializable with Logging {
3 * numBins * numFeatures * nodeIndex +
3 * numBins * featureIndex +
3 * binIndex
agg(aggIndex) = agg(aggIndex) + 1
agg(aggIndex + 1) = agg(aggIndex + 1) + label
agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
agg(aggIndex) += 1
agg(aggIndex + 1) += label
agg(aggIndex + 2) += label * label
featureIndex += 1
}
}
@@ -1165,7 +1157,7 @@ object DecisionTree extends Serializable with Logging {
// Categorical feature
val featureCategories = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
math.pow(2.0, featureCategories - 1).toInt - 1
(1 << featureCategories - 1) - 1
} else {
featureCategories
}
@@ -1257,7 +1249,7 @@ object DecisionTree extends Serializable with Logging {
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1302,7 +1294,7 @@ object DecisionTree extends Serializable with Logging {
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
* There are math.pow(2, maxFeatureValue - 1) - 1 splits.
* There are (1 << maxFeatureValue - 1) - 1 splits.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
@@ -1391,7 +1383,7 @@ object DecisionTree extends Serializable with Logging {
if (metadata.isUnordered(featureIndex)) {
// 2^(maxFeatureValue - 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
while (index < (1 << featureCategories - 1) - 1) {
val categories: List[Double]
= extractMultiClassCategories(index + 1, featureCategories)
splits(featureIndex)(index)
@@ -73,7 +73,7 @@ private[tree] object DTMetadata {
val unorderedFeatures = new mutable.HashSet[Int]()
if (numClasses > 2) {
strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
val numUnorderedBins = math.pow(2, k - 1) - 1
val numUnorderedBins = (1 << k - 1) - 1
Contributor
Need to check the value of k first. If k > 30, the result will be unexpected. Using 1L instead of 1 may help.
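A sketch of the guard being suggested (illustrative only; it reuses k and maxBins from the surrounding diff and assumes nothing else):

```scala
// 1 << (k - 1) overflows Int once k - 1 >= 31, so the unguarded shift can go
// negative and wrongly pass the `< maxBins` check below. Computing in Long keeps
// the bin count well-defined for any realistic k; very large k could also be
// rejected up front with require().
val numUnorderedBins: Long = (1L << (k - 1)) - 1
```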

if (numUnorderedBins < maxBins) {
unorderedFeatures.add(f)
} else {
@@ -18,7 +18,6 @@
package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD
