Commits (27)
a95bc22  timing for DecisionTree internals (jkbradley, Aug 5, 2014)
511ec85  Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 6, 2014)
bcf874a  Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 7, 2014)
f61e9d2  Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 8, 2014)
3211f02  Optimizing DecisionTree (jkbradley, Aug 8, 2014)
0f676e2  Optimizations + Bug fix for DecisionTree (jkbradley, Aug 8, 2014)
b2ed1f3  Merge remote-tracking branch 'upstream/master' into dt-opt (jkbradley, Aug 8, 2014)
b914f3b  DecisionTree optimization: eliminated filters + small changes (jkbradley, Aug 9, 2014)
c1565a5  Small DecisionTree updates: (jkbradley, Aug 11, 2014)
a87e08f  Merge remote-tracking branch 'upstream/master' into dt-opt1 (jkbradley, Aug 14, 2014)
8464a6e  Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. … (jkbradley, Aug 14, 2014)
e66f1b1  TreePoint (jkbradley, Aug 14, 2014)
d036089  Print timing info to logDebug. (jkbradley, Aug 14, 2014)
430d782  Added more debug info on binning error. Added some docs. (jkbradley, Aug 14, 2014)
356daba  Merge branch 'dt-opt1' into dt-opt2 (jkbradley, Aug 14, 2014)
26d10dd  Removed tree/model/Filter.scala since no longer used. Removed debugg… (jkbradley, Aug 15, 2014)
2d2aaaf  Merge remote-tracking branch 'upstream/master' into dt-opt1 (jkbradley, Aug 15, 2014)
6b5651e  Updates based on code review. 1 major change: persisting to memory +… (jkbradley, Aug 15, 2014)
5f2dec2  Fixed scalastyle issue in TreePoint (jkbradley, Aug 15, 2014)
f40381c  Merge branch 'dt-opt1' into dt-opt2 (jkbradley, Aug 15, 2014)
797f68a  Fixed DecisionTreeSuite bug for training second level. Needed to upd… (jkbradley, Aug 15, 2014)
931a3a7  Merge remote-tracking branch 'upstream/master' into dt-opt2 (jkbradley, Aug 15, 2014)
6a38f48  Added DTMetadata class for cleaner code (jkbradley, Aug 16, 2014)
db0d773  scala style fix (jkbradley, Aug 16, 2014)
ac0b9f8  Small updates based on code review. (jkbradley, Aug 16, 2014)
3726d20  Small code improvements based on code review. (jkbradley, Aug 17, 2014)
a0ed0da  Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. (jkbradley, Aug 17, 2014)
130 changes: 30 additions & 100 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,7 +17,6 @@

package org.apache.spark.mllib.tree

import java.util.Calendar

import org.apache.spark.mllib.linalg.Vector

@@ -28,47 +27,16 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl.TreePoint
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom

class TimeTracker {

var tmpTime: Long = Calendar.getInstance().getTimeInMillis

def reset(): Unit = {
tmpTime = Calendar.getInstance().getTimeInMillis
}

def elapsed(): Long = {
Calendar.getInstance().getTimeInMillis - tmpTime
}

var initTime: Long = 0 // Data retag and cache
var findSplitsBinsTime: Long = 0
var extractNodeInfoTime: Long = 0
var extractInfoForLowerLevelsTime: Long = 0
var findBestSplitsTime: Long = 0
var binAggregatesTime: Long = 0
var chooseSplitsTime: Long = 0

override def toString: String = {
s"DecisionTree timing\n" +
s"initTime: $initTime\n" +
s"findSplitsBinsTime: $findSplitsBinsTime\n" +
s"extractNodeInfoTime: $extractNodeInfoTime\n" +
s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" +
s"findBestSplitsTime: $findBestSplitsTime\n" +
s"binAggregatesTime: $binAggregatesTime\n" +
s"chooseSplitsTime: $chooseSplitsTime\n"
}
}
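
For reference, a minimal sketch of what the relocated TimeTracker (moved to tree/impl/ by commit 8464a6e) might look like, inferred from the timer.start(...)/timer.stop(...) calls used below; the actual implementation may differ:

import scala.collection.mutable

private[tree] class TimeTracker extends Serializable {

  private val starts: mutable.Map[String, Long] = mutable.Map.empty  // open timers, keyed by label
  private val totals: mutable.Map[String, Long] = mutable.Map.empty  // accumulated millis, keyed by label

  /** Starts (or restarts) the timer for the given label. */
  def start(label: String): Unit = {
    starts(label) = System.currentTimeMillis()
  }

  /** Stops the timer for the given label and adds the elapsed time to its running total. */
  def stop(label: String): Unit = {
    val elapsed = System.currentTimeMillis() - starts(label)
    starts.remove(label)
    totals(label) = totals.getOrElse(label, 0L) + elapsed
  }

  /** Prints one "label: seconds" line per timed section. */
  override def toString: String =
    totals.map { case (label, millis) => s"$label: ${millis / 1000.0}" }.mkString("\n")
}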

/**
* :: Experimental ::
@@ -91,26 +59,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

val timer = new TimeTracker()
timer.reset()

timer.start("total")

// Cache input RDD for speedup during multiple passes.
timer.start("init")
val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)

timer.initTime += timer.elapsed()
timer.reset()
timer.stop("init")

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(retaggedInput, strategy)
val numBins = bins(0).length
timer.stop("findSplitsBins")
logDebug("numBins = " + numBins)

timer.findSplitsBinsTime += timer.elapsed()

timer.reset()
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
timer.initTime += timer.elapsed()
timer.start("init")
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
timer.stop("init")
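
The cache() call here reflects commit 6b5651e (persisting the binned data to memory). TreePoint itself is the key representation change in this PR: each instance's raw feature values are discretized into bin indices once, so the level-by-level passes below compare Ints instead of re-binning Doubles. A rough sketch of its shape, with the field name taken from the treeInput.take(1)(0).features usage below (the real class lives in org.apache.spark.mllib.tree.impl.TreePoint and may differ):

// Sketch only: the label plus one precomputed bin index per feature.
private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable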

// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -127,7 +95,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later?
nodesInTree(0) = true
Contributor: Is it used anywhere else?

Member Author: Nope; I intended to use it for another optimization (ignoring unused branches of the tree) but will postpone that. Will remove it for now.

// num features
val numFeatures = retaggedInput.take(1)(0).features.size
val numFeatures = treeInput.take(1)(0).features.size

// Calculate level for single group construction

@@ -155,10 +123,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes.
*/

var findBestSplitsTime: Long = 0
var extractNodeInfoTime: Long = 0
var extractInfoForLowerLevelsTime: Long = 0

var level = 0
var break = false
while (level <= maxDepth && !break) {
@@ -169,10 +133,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo


// Find best split for all nodes at a level.
timer.reset()
timer.start("findBestSplits")
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
strategy, level, nodes, splits, bins, maxLevelForSingleGroup, unorderedFeatures, timer)
timer.findBestSplitsTime += timer.elapsed()
timer.stop("findBestSplits")

val levelNodeIndexOffset = math.pow(2, level).toInt - 1
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
@@ -186,9 +150,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf))
// TODO: Use above check to skip unused branch of tree
// Extract info for this node (index) at the current level.
timer.reset()
timer.start("extractNodeInfo")
extractNodeInfo(nodeSplitStats, level, index, nodes)
timer.extractNodeInfoTime += timer.elapsed()
timer.stop("extractNodeInfo")
if (level != 0) {
// Set parent.
if (isLeftChild) {
@@ -198,9 +162,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}
// Extract info for nodes at the next lower level.
timer.reset()
timer.start("extractInfoForLowerLevels")
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
timer.extractInfoForLowerLevelsTime += timer.elapsed()
timer.stop("extractInfoForLowerLevels")
logDebug("final best split = " + nodeSplitStats._1)
}
require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -215,8 +179,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}

println(timer)

logDebug("#####################################")
logDebug("Extracting tree model")
logDebug("#####################################")
@@ -226,6 +188,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Build the full tree using the node info calculated in the level-wise best split calculations.
topNode.build(nodes)

timer.stop("total")

logDebug("Internal timing for DecisionTree:")
logDebug(s"$timer")

new DecisionTreeModel(topNode, strategy.algo)
}

@@ -257,11 +224,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities: Array[Double]): Unit = {
if (level >= maxDepth)
return
//filters: Array[List[Filter]]): Unit = {
// 0 corresponds to the left child node and 1 corresponds to the right child node.
var i = 0
while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
val impurity = if (i == 0) {
nodeSplitStats._2.leftImpurity
@@ -273,13 +239,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities(nodeIndex) = impurity
Contributor: Since most of the code is removed, maybe we can unfold the while loop:

val leftNodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index
parentImpurities(leftNodeIndex) = nodeSplitStats._2.leftImpurity
val rightNodeIndex = leftNodeIndex + 1
parentImpurities(rightNodeIndex) = nodeSplitStats._2.rightImpurity

Btw, could the code be simplified if we don't use leftNode and rightNode but childNodes: Array[Node]?

Member Author: I'm not sure that would help here, but at some point down the line (not this PR), I do want to include this impurity info in the growing tree structure itself. I will make a JIRA for it.

// noting the parents filters for the child nodes
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
/*
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
//println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
for (filter <- filters(nodeIndex)) {
logDebug("Filter = " + filter)
}
*/
i += 1
}
}
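
The index arithmetic above is the standard array layout of a complete binary tree: the node at (level, i) is stored at index 2^level - 1 + i, so the child index computed in the loop is the usual heap rule, 2 * nodeIndex + 1. A small self-contained sketch of the equivalence (illustrative only):

// Flat-array indexing for a complete binary tree of nodes.
def nodeIndex(level: Int, i: Int): Int = (1 << level) - 1 + i
def leftChildIndex(node: Int): Int = 2 * node + 1
def rightChildIndex(node: Int): Int = 2 * node + 2

// e.g. node (1, 1) sits at index 2; its left child, index 5, is node (2, 2):
assert(leftChildIndex(nodeIndex(1, 1)) == nodeIndex(2, 2))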
@@ -516,7 +475,6 @@ object DecisionTree extends Serializable with Logging {
unorderedFeatures: Set[Int],
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
//println(s"findBestSplits: level = $level")
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
// the nodes are divided into multiple groups at each level with the number of groups
@@ -540,7 +498,7 @@
}
}
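
The branch truncated by the hunk above amounts to: with 2^level nodes at the current level, process them in 2^(level - maxLevelForSingleGroup) groups so that each group's bin aggregates fit in memory. A hedged sketch of that shape, assuming a per-group helper named findBestSplitsPerGroup (the name and signature are assumptions here):

// Sketch only; uses the enclosing method's parameters.
val numGroups = 1 << (level - maxLevelForSingleGroup)
var bestSplits = new Array[(Split, InformationGainStats)](0)
var groupIndex = 0
while (groupIndex < numGroups) {
  val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
    nodes, splits, bins, unorderedFeatures, timer, numGroups, groupIndex)
  bestSplits = bestSplits ++ bestSplitsForGroup
  groupIndex += 1
}
bestSplits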

/**
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
@@ -898,39 +856,15 @@ object DecisionTree extends Serializable with Logging {
combinedAggregate
}

timer.reset()

// Calculate bin aggregates.
timer.start("binAggregates")
val binAggregates = {
input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
}
timer.stop("binAggregates")
logDebug("binAggregates.length = " + binAggregates.length)

timer.binAggregatesTime += timer.elapsed()
//2 * numClasses * numBins * numFeatures * numNodes for unordered features.
// (left/right, node, feature, bin, label)
/*
println(s"binAggregates:")
for (i <- Range(0,2)) {
for (n <- Range(0,numNodes)) {
for (f <- Range(0,numFeatures)) {
for (b <- Range(0,4)) {
for (c <- Range(0,numClasses)) {
val idx = i * numClasses * numBins * numFeatures * numNodes +
n * numClasses * numBins * numFeatures +
f * numBins * numFeatures +
b * numFeatures +
c
if (binAggregates(idx) != 0) {
println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
}
}
}
}
}
}
*/
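
The comment above describes a five-dimensional histogram, (left/right, node, feature, bin, label), flattened into a single Double array. For illustration, a consistent row-major index under that dimension order (an assumption: the PR's actual stride order may differ, as the removed debug code suggests):

// Row-major flattening of (left/right, node, feature, bin, label); sketch only.
def flatIndex(lr: Int, node: Int, feature: Int, bin: Int, label: Int,
    numNodes: Int, numFeatures: Int, numBins: Int, numClasses: Int): Int =
  (((lr * numNodes + node) * numFeatures + feature) * numBins + bin) * numClasses + label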

/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
* @param leftNodeAgg left node aggregates for this (feature, split)
@@ -965,7 +899,6 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftTotalCount + rightTotalCount
if (totalCount == 0) {
// Return arbitrary prediction.
//println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
}
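
In the non-degenerate case, the gain this method computes is the usual weighted-impurity improvement:

gain = I(parent) - (leftCount / totalCount) * I(left) - (rightCount / totalCount) * I(right)

where I is the configured impurity measure (Gini, Entropy, or Variance) and totalCount = leftCount + rightCount; a split is an improvement when the count-weighted child impurities sum to less than the parent's impurity.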

@@ -1311,7 +1244,6 @@ object DecisionTree extends Serializable with Logging {
bestGainStats = gainStats
bestFeatureIndex = featureIndex
bestSplitIndex = splitIndex
//println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
}
splitIndex += 1
}
@@ -1356,9 +1288,8 @@
}
}

timer.reset()

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
// Iterating over all nodes at this level
var node = 0
@@ -1369,10 +1300,9 @@ object DecisionTree extends Serializable with Logging {
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
logDebug("parent node impurity = " + parentNodeImpurity)
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
//println(s"bestSplits(node:$node): ${bestSplits(node)}")
node += 1
}
timer.chooseSplitsTime += timer.elapsed()
timer.stop("chooseSplits")

bestSplits
}
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
/**
* :: Experimental ::
* Stores all the configuration options for tree construction
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @param impurity Criterion used for information gain calculation.
* Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
* [[org.apache.spark.mllib.tree.impurity.Entropy]].
* Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value is 2
* leads to binary classification
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
* Default value is 2 (binary classification).
* @param maxBins Maximum number of bins used for discretizing continuous features and
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
* [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 128 MB.
*
*
*/
@Experimental
class Strategy (
@@ -64,20 +72,7 @@ class Strategy (
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)

/**
* Java-friendly constructor.
*
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value is 2
* leads to binary classification
* @param maxBins maximum number of bins used for splitting features
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example, an entry
* (n -> k) implies the feature n is categorical with k categories
* 0, 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
def this(
algo: Algo,
@@ -90,6 +85,10 @@
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
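
A hypothetical usage sketch for the documented parameters (all values here are illustrative, not defaults): a 3-class problem in which feature 0 is categorical with 4 categories and all other features are continuous.

import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.QuantileStrategy
import org.apache.spark.mllib.tree.impurity.Gini

val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = Gini,
  maxDepth = 5,
  numClassesForClassification = 3,
  maxBins = 32,
  quantileCalculationStrategy = QuantileStrategy.Sort,
  categoricalFeaturesInfo = Map(0 -> 4))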

/**
* Check validity of parameters.
* Throws exception if invalid.
*/
private[tree] def assertValid(): Unit = {
algo match {
case Classification =>