-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
a95bc22
511ec85
bcf874a
f61e9d2
3211f02
0f676e2
b2ed1f3
b914f3b
c1565a5
a87e08f
8464a6e
e66f1b1
d036089
430d782
356daba
26d10dd
2d2aaaf
6b5651e
5f2dec2
f40381c
797f68a
931a3a7
6a38f48
db0d773
ac0b9f8
3726d20
a0ed0da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,6 @@ | |
|
|
||
| package org.apache.spark.mllib.tree | ||
|
|
||
| import java.util.Calendar | ||
|
|
||
| import org.apache.spark.mllib.linalg.Vector | ||
|
|
||
|
|
@@ -28,47 +27,16 @@ import org.apache.spark.annotation.Experimental | |
| import org.apache.spark.api.java.JavaRDD | ||
| import org.apache.spark.Logging | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} | ||
| import org.apache.spark.mllib.tree.configuration.Strategy | ||
| import org.apache.spark.mllib.tree.configuration.Algo._ | ||
| import org.apache.spark.mllib.tree.configuration.FeatureType._ | ||
| import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ | ||
| import org.apache.spark.mllib.tree.impl.TreePoint | ||
| import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} | ||
| import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} | ||
| import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} | ||
| import org.apache.spark.mllib.tree.model._ | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.util.random.XORShiftRandom | ||
|
|
||
| class TimeTracker { | ||
|
|
||
| var tmpTime: Long = Calendar.getInstance().getTimeInMillis | ||
|
|
||
| def reset(): Unit = { | ||
| tmpTime = Calendar.getInstance().getTimeInMillis | ||
| } | ||
|
|
||
| def elapsed(): Long = { | ||
| Calendar.getInstance().getTimeInMillis - tmpTime | ||
| } | ||
|
|
||
| var initTime: Long = 0 // Data retag and cache | ||
| var findSplitsBinsTime: Long = 0 | ||
| var extractNodeInfoTime: Long = 0 | ||
| var extractInfoForLowerLevelsTime: Long = 0 | ||
| var findBestSplitsTime: Long = 0 | ||
| var binAggregatesTime: Long = 0 | ||
| var chooseSplitsTime: Long = 0 | ||
|
|
||
| override def toString: String = { | ||
| s"DecisionTree timing\n" + | ||
| s"initTime: $initTime\n" + | ||
| s"findSplitsBinsTime: $findSplitsBinsTime\n" + | ||
| s"extractNodeInfoTime: $extractNodeInfoTime\n" + | ||
| s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" + | ||
| s"findBestSplitsTime: $findBestSplitsTime\n" + | ||
| s"binAggregatesTime: $binAggregatesTime\n" + | ||
| s"chooseSplitsTime: $chooseSplitsTime\n" | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
|
|
@@ -91,26 +59,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| def train(input: RDD[LabeledPoint]): DecisionTreeModel = { | ||
|
|
||
| val timer = new TimeTracker() | ||
| timer.reset() | ||
|
|
||
| timer.start("total") | ||
|
|
||
| // Cache input RDD for speedup during multiple passes. | ||
| timer.start("init") | ||
| val retaggedInput = input.retag(classOf[LabeledPoint]) | ||
| logDebug("algo = " + strategy.algo) | ||
|
|
||
| timer.initTime += timer.elapsed() | ||
| timer.reset() | ||
| timer.stop("init") | ||
|
|
||
| // Find the splits and the corresponding bins (interval between the splits) using a sample | ||
| // of the input data. | ||
| timer.start("findSplitsBins") | ||
| val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(retaggedInput, strategy) | ||
| val numBins = bins(0).length | ||
| timer.stop("findSplitsBins") | ||
| logDebug("numBins = " + numBins) | ||
|
|
||
| timer.findSplitsBinsTime += timer.elapsed() | ||
|
|
||
| timer.reset() | ||
| val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) | ||
| timer.initTime += timer.elapsed() | ||
| timer.start("init") | ||
| val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache() | ||
| timer.stop("init") | ||
|
|
||
| // depth of the decision tree | ||
| val maxDepth = strategy.maxDepth | ||
|
|
@@ -127,7 +95,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? | ||
| nodesInTree(0) = true | ||
| // num features | ||
| val numFeatures = retaggedInput.take(1)(0).features.size | ||
| val numFeatures = treeInput.take(1)(0).features.size | ||
|
|
||
| // Calculate level for single group construction | ||
|
|
||
|
|
@@ -155,10 +123,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| * still survived the filters of the parent nodes. | ||
| */ | ||
|
|
||
| var findBestSplitsTime: Long = 0 | ||
| var extractNodeInfoTime: Long = 0 | ||
| var extractInfoForLowerLevelsTime: Long = 0 | ||
|
|
||
| var level = 0 | ||
| var break = false | ||
| while (level <= maxDepth && !break) { | ||
|
|
@@ -169,10 +133,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
|
|
||
|
|
||
| // Find best split for all nodes at a level. | ||
| timer.reset() | ||
| timer.start("findBestSplits") | ||
| val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, | ||
| strategy, level, nodes, splits, bins, maxLevelForSingleGroup, unorderedFeatures, timer) | ||
| timer.findBestSplitsTime += timer.elapsed() | ||
| timer.stop("findBestSplits") | ||
|
|
||
| val levelNodeIndexOffset = math.pow(2, level).toInt - 1 | ||
| for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { | ||
|
|
@@ -186,9 +150,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) | ||
| // TODO: Use above check to skip unused branch of tree | ||
| // Extract info for this node (index) at the current level. | ||
| timer.reset() | ||
| timer.start("extractNodeInfo") | ||
| extractNodeInfo(nodeSplitStats, level, index, nodes) | ||
| timer.extractNodeInfoTime += timer.elapsed() | ||
| timer.stop("extractNodeInfo") | ||
| if (level != 0) { | ||
| // Set parent. | ||
| if (isLeftChild) { | ||
|
|
@@ -198,9 +162,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| } | ||
| } | ||
| // Extract info for nodes at the next lower level. | ||
| timer.reset() | ||
| timer.start("extractInfoForLowerLevels") | ||
| extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) | ||
| timer.extractInfoForLowerLevelsTime += timer.elapsed() | ||
| timer.stop("extractInfoForLowerLevels") | ||
| logDebug("final best split = " + nodeSplitStats._1) | ||
| } | ||
| require(math.pow(2, level) == splitsStatsForLevel.length) | ||
|
|
@@ -215,8 +179,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| } | ||
| } | ||
|
|
||
| println(timer) | ||
|
|
||
| logDebug("#####################################") | ||
| logDebug("Extracting tree model") | ||
| logDebug("#####################################") | ||
|
|
@@ -226,6 +188,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| // Build the full tree using the node info calculated in the level-wise best split calculations. | ||
| topNode.build(nodes) | ||
|
|
||
| timer.stop("total") | ||
|
|
||
| logDebug("Internal timing for DecisionTree:") | ||
| logDebug(s"$timer") | ||
|
|
||
| new DecisionTreeModel(topNode, strategy.algo) | ||
| } | ||
|
|
||
|
|
@@ -257,11 +224,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| parentImpurities: Array[Double]): Unit = { | ||
| if (level >= maxDepth) | ||
| return | ||
| //filters: Array[List[Filter]]): Unit = { | ||
| // 0 corresponds to the left child node and 1 corresponds to the right child node. | ||
| var i = 0 | ||
| while (i <= 1) { | ||
| // Calculate the index of the node from the node level and the index at the current level. | ||
| // Calculate the index of the node from the node level and the index at the current level. | ||
| val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i | ||
| val impurity = if (i == 0) { | ||
| nodeSplitStats._2.leftImpurity | ||
|
|
@@ -273,13 +239,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo | |
| parentImpurities(nodeIndex) = impurity | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since most of the code is removed, maybe we can unfold the while loop: Btw, could the code be simplified if we don't use
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure that would help here, but at some point down the line (not this PR), I do want to include this impurity info in the growing tree structure itself. I will make a JIRA for it. |
||
| // noting the parents filters for the child nodes | ||
| val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) | ||
| /* | ||
| filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) | ||
| //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") | ||
| for (filter <- filters(nodeIndex)) { | ||
| logDebug("Filter = " + filter) | ||
| } | ||
| */ | ||
| i += 1 | ||
| } | ||
| } | ||
|
|
@@ -516,7 +475,6 @@ object DecisionTree extends Serializable with Logging { | |
| unorderedFeatures: Set[Int], | ||
| timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { | ||
| // split into groups to avoid memory overflow during aggregation | ||
| //println(s"findBestSplits: level = $level") | ||
| if (level > maxLevelForSingleGroup) { | ||
| // When information for all nodes at a given level cannot be stored in memory, | ||
| // the nodes are divided into multiple groups at each level with the number of groups | ||
|
|
@@ -540,7 +498,7 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| /** | ||
| * Returns an array of optimal splits for a group of nodes at a given level | ||
| * | ||
| * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] | ||
|
|
@@ -898,39 +856,15 @@ object DecisionTree extends Serializable with Logging { | |
| combinedAggregate | ||
| } | ||
|
|
||
| timer.reset() | ||
|
|
||
| // Calculate bin aggregates. | ||
| timer.start("binAggregates") | ||
| val binAggregates = { | ||
| input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) | ||
| } | ||
| timer.stop("binAggregates") | ||
| logDebug("binAggregates.length = " + binAggregates.length) | ||
|
|
||
| timer.binAggregatesTime += timer.elapsed() | ||
| //2 * numClasses * numBins * numFeatures * numNodes for unordered features. | ||
| // (left/right, node, feature, bin, label) | ||
| /* | ||
| println(s"binAggregates:") | ||
| for (i <- Range(0,2)) { | ||
| for (n <- Range(0,numNodes)) { | ||
| for (f <- Range(0,numFeatures)) { | ||
| for (b <- Range(0,4)) { | ||
| for (c <- Range(0,numClasses)) { | ||
| val idx = i * numClasses * numBins * numFeatures * numNodes + | ||
| n * numClasses * numBins * numFeatures + | ||
| f * numBins * numFeatures + | ||
| b * numFeatures + | ||
| c | ||
| if (binAggregates(idx) != 0) { | ||
| println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}") | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| */ | ||
|
|
||
| /** | ||
| * Calculate the information gain for a given (feature, split) based upon left/right aggregates. | ||
| * @param leftNodeAgg left node aggregates for this (feature, split) | ||
|
|
@@ -965,7 +899,6 @@ object DecisionTree extends Serializable with Logging { | |
| val totalCount = leftTotalCount + rightTotalCount | ||
| if (totalCount == 0) { | ||
| // Return arbitrary prediction. | ||
| //println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0") | ||
| return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) | ||
| } | ||
|
|
||
|
|
@@ -1311,7 +1244,6 @@ object DecisionTree extends Serializable with Logging { | |
| bestGainStats = gainStats | ||
| bestFeatureIndex = featureIndex | ||
| bestSplitIndex = splitIndex | ||
| //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") | ||
| } | ||
| splitIndex += 1 | ||
| } | ||
|
|
@@ -1356,9 +1288,8 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
| } | ||
|
|
||
| timer.reset() | ||
|
|
||
| // Calculate best splits for all nodes at a given level | ||
| timer.start("chooseSplits") | ||
| val bestSplits = new Array[(Split, InformationGainStats)](numNodes) | ||
| // Iterating over all nodes at this level | ||
| var node = 0 | ||
|
|
@@ -1369,10 +1300,9 @@ object DecisionTree extends Serializable with Logging { | |
| val parentNodeImpurity = parentImpurities(nodeImpurityIndex) | ||
| logDebug("parent node impurity = " + parentNodeImpurity) | ||
| bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) | ||
| //println(s"bestSplits(node:$node): ${bestSplits(node)}") | ||
| node += 1 | ||
| } | ||
| timer.chooseSplitsTime += timer.elapsed() | ||
| timer.stop("chooseSplits") | ||
|
|
||
| bestSplits | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it used anywhere else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope; I intended to use it for another optimization (ignoring unused branches of the tree) but will postpone that. Will remove it for now.