Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala.
  • Loading branch information
jkbradley committed Aug 15, 2014
commit 26d10dd58ee218102bd205c1e6d68fda5a45cf4b
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.start("total")

// Cache input RDD for speedup during multiple passes.
timer.start("init")
val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)
Expand All @@ -77,17 +76,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("numBins = " + numBins)

timer.start("init")
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
timer.stop("init")

// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
// Initialize an array to hold filters applied to points for each node.
//val filters = new Array[List[Filter]](maxNumNodes)
// The filter at the top node is an empty list.
//filters(0) = List()
// Initialize an array to hold parent impurity calculations for each node.
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
Expand Down Expand Up @@ -118,9 +115,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
* Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
* the sample is only used for the split calculation at the node if the sampled would have
* still survived the filters of the parent nodes.
* Each data sample is handled by a particular node at that level (or it reaches a leaf
* beforehand and is not used in later levels).
*/

var level = 0
Expand Down Expand Up @@ -169,7 +165,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
require(math.pow(2, level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level are leaves.
println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
if (allLeaf) {
Expand Down Expand Up @@ -237,8 +232,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
// noting the parent impurities
parentImpurities(nodeIndex) = impurity
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since most of the code is removed, maybe we can unfold the while loop:

val leftNodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index
parentImpurities(leftNodeIndex) = nodeSplitStats._2.leftImpurity
val rightNodeIndex = leftNodeIndex + 1
parentImpurities(rightNodeIndex) = nodeSplitStats._2.rightImpurity

Btw, could the code be simplified if we don't use leftNode and rightNode but childNodes: Array[Node]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that would help here, but at some point down the line (not this PR), I do want to include this impurity info in the growing tree structure itself. I will make a JIRA for it.

// noting the parents filters for the child nodes
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
i += 1
}
}
Expand Down Expand Up @@ -461,7 +454,6 @@ object DecisionTree extends Serializable with Logging {
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @param unorderedFeatures Set of unordered (categorical) features.
* @return array (over nodes) of splits with best split for each node at a given level.
* TODO: UPDATE DOC
*/
protected[tree] def findBestSplits(
input: RDD[TreePoint],
Expand Down Expand Up @@ -512,7 +504,6 @@ object DecisionTree extends Serializable with Logging {
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
* @return array of splits with best splits for all nodes at a given level.
* TODO: UPDATE DOC
*/
private def findBestSplitsPerGroup(
input: RDD[TreePoint],
Expand All @@ -539,7 +530,7 @@ object DecisionTree extends Serializable with Logging {
* We use a bin-wise best split computation strategy instead of a straightforward best split
* computation strategy. Instead of analyzing each sample for contribution to the left/right
* child node impurity of every split, we first categorize each feature of a sample into a
* bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
* bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
* is ordered (read ordering for categorical variables in the findSplitsBins method),
* we exploit this structure to calculate aggregates for bins and then use these aggregates
* to calculate information gain for each split.
Expand Down Expand Up @@ -660,7 +651,6 @@ object DecisionTree extends Serializable with Logging {
* numClasses * numBins * numFeatures * numNodes.
* Indexed by (node, feature, bin, label) where label is the least significant bit.
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
* TODO: UPDATE DOC
*/
def updateBinForOrderedFeature(
treePoint: TreePoint,
Expand All @@ -681,21 +671,19 @@ object DecisionTree extends Serializable with Logging {
* where [bins] ranges over all bins.
* Updates left or right side of aggregate depending on split.
*
* @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
* @param treePoint Data point being aggregated.
* @param agg Indexed by (left/right, node, feature, bin, label)
* where label is the least significant bit.
* The left/right specifier is a 0/1 index indicating left/right child info.
* @param rightChildShift Offset for right side of agg.
* TODO: UPDATE DOC
* TODO: Make arg order same as for ordered feature.
*/
def updateBinForUnorderedFeature(
nodeIndex: Int,
featureIndex: Int,
treePoint: TreePoint,
agg: Array[Double],
rightChildShift: Int): Unit = {
//println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
val featureValue = treePoint.features(featureIndex)
// Update the left or right count for one bin.
val aggShift =
Expand Down Expand Up @@ -780,7 +768,6 @@ object DecisionTree extends Serializable with Logging {
* @return agg
*/
def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
// TODO: Move stuff outside loop.
val label = treePoint.label
// Iterate over all features.
var featureIndex = 0
Expand All @@ -791,9 +778,6 @@ object DecisionTree extends Serializable with Logging {
3 * numBins * numFeatures * nodeIndex +
3 * numBins * featureIndex +
3 * binIndex
if (aggIndex >= agg.size) {
println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures")
}
agg(aggIndex) = agg(aggIndex) + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use agg(aggIndex) += 1?

agg(aggIndex + 1) = agg(aggIndex + 1) + label
agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
Expand Down Expand Up @@ -1025,7 +1009,6 @@ object DecisionTree extends Serializable with Logging {
* Element i (i = 1, ..., numSplits - 1) is set to be
* the cumulative sum (from right) over binData for bins
* numBins - 1, ..., numBins - 1 - i.
* TODO: We could avoid doing one of these cumulative sums.
*/
def findAggForOrderedFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
Expand Down Expand Up @@ -1196,16 +1179,6 @@ object DecisionTree extends Serializable with Logging {
} else {
featureCategories
}
/*
val isSpaceSufficientForAllCategoricalSplits =
numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else {
// Ordered features
featureCategories
}
*/
}
}

Expand Down

This file was deleted.