7 changes: 2 additions & 5 deletions docs/mllib-classification-regression.md
Original file line number Diff line number Diff line change
@@ -294,12 +294,9 @@ The recursive tree construction is stopped at a node when one of the two conditi
1. The node depth is equal to the `maxDepth` training parameter
2. No split candidate leads to an information gain at the node.

### Practical Limitations

The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. The current implementation might not scale to very deep trees since the memory requirement grows exponentially with tree depth.

Please drop us a line if you encounter any issues. We are planning to solve this problem in the near future and real-world examples will be great.
### Implementation Details
[Review comment, Contributor]
FYI, the decision tree guide is now in mllib-decision-tree.md.

[Review comment, Contributor (PR author)]
Thanks.


The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. Based upon the 'maxMemory' parameter set during training (default is 128 MB), the task is broken down into smaller groups to avoid out-of-memory errors during computation.

## Implementation in MLlib

Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.Utils.memoryStringToMb
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
@@ -58,7 +59,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
logDebug("numSplits = " + bins(0).length)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -72,7 +74,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = input.take(1)(0).features.size

// Calculate level for single group construction

// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemory * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
[Review comment, Contributor]
+ "bytes."?

val numElementsPerNode = {
[Review comment, Contributor]
Not necessary to have an extra { } pair.

strategy.algo match {
case Classification => 2 * numBins * numFeatures
[Review comment, Contributor]
Remove extra space between * and numFeatures.

case Regression => 3 * numBins * numFeatures
}
}
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1)
[Review comment, Contributor]
why not just use math.max

[Review comment, Contributor (PR author)]
@techaddict Happy to change it. Is it cosmetic, or is there something more to it?

[Review comment, Contributor]
@manishamde just cleanliness.

logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
// nodes at a level is 2^(level-1). level is zero indexed.
[Review comment, Contributor]
nodes -> Number of nodes. If level is zero indexed, when level = 0, we have 2^(level - 1) = 1/2. Is it expected?

val maxLevelForSingleGroup = scala.math.max(
(scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0)
[Review comment, Contributor]
same here

logDebug("max level for single group = " + maxLevelForSingleGroup)
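
As a rough illustration of the sizing logic in the hunk above, the following standalone Scala sketch reproduces the arithmetic outside Spark. The object and method names (`GroupSizing`, `maxNodesPerGroup`, and so on) are hypothetical, the feature count of 10 is an assumed value, and the use of `Long` for the byte budget is an editorial choice rather than what the diff does (the diff multiplies `Int` values, which would overflow for a `maxMemory` of 2048 or more).

```scala
object GroupSizing {
  // Number of Double values aggregated per node, mirroring the match in the diff:
  // 2 per (bin, feature) for classification, 3 for regression.
  def numElementsPerNode(isClassification: Boolean, numBins: Int, numFeatures: Int): Int =
    (if (isClassification) 2 else 3) * numBins * numFeatures

  // Largest number of nodes whose aggregate arrays fit in the memory budget.
  def maxNodesPerGroup(maxMemoryInMB: Int, elementsPerNode: Int): Long = {
    val maxMemoryUsage = maxMemoryInMB.toLong * 1024 * 1024 // bytes (Long is an assumption, see lead-in)
    val arraySizePerNode = 8L * elementsPerNode             // 8 bytes per Double
    math.max(maxMemoryUsage / arraySizePerNode, 1L)
  }

  // Deepest zero-indexed level handled as a single group, using the same
  // floor(log2(n)) - 1 expression as the diff, clamped at 0.
  def maxLevelForSingleGroup(maxNodes: Long): Int =
    math.max((math.log(maxNodes.toDouble) / math.log(2)).floor.toInt - 1, 0)

  def main(args: Array[String]): Unit = {
    // Assumed inputs: classification, 100 bins, 10 features, 128 MB budget.
    val elems = numElementsPerNode(isClassification = true, numBins = 100, numFeatures = 10)
    val nodes = maxNodesPerGroup(128, elems)
    println(s"elements per node = $elems")
    println(s"max nodes per group = $nodes")
    println(s"max level for single group = ${maxLevelForSingleGroup(nodes)}")
  }
}
```

With these assumed inputs the sketch should report 2000 elements per node, 8388 nodes per group, and a single-group limit of level 12.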

/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
@@ -92,7 +115,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins)
level, filters, splits, bins, maxLevelForSingleGroup)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
@@ -110,6 +133,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}

logDebug("#####################################")
logDebug("Extracting tree model")
logDebug("#####################################")

// Initialize the top or root node of the tree.
val topNode = nodes(0)
// Build the full tree using the node info calculated in the level-wise best split calculations.
@@ -260,6 +287,7 @@ object DecisionTree extends Serializable with Logging {
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array of splits with best splits for all nodes at a given level.
*/
protected[tree] def findBestSplits(
@@ -269,7 +297,50 @@ object DecisionTree extends Serializable with Logging {
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
[Review comment, Contributor]
Remove extra space after >.

val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt
logDebug("numGroups = " + numGroups)
var groupIndex = 0
var bestSplits = new Array[(Split, InformationGainStats)](0)
while (groupIndex < numGroups) {
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
filters, splits, bins, numGroups, groupIndex)
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
groupIndex += 1
}
bestSplits
} else {
findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
}
}

/**
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
* for DecisionTree
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
* parameters for constructing the DecisionTree
* @param level Level of the tree
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
* @return array of splits with best splits for all nodes at a given level.
[Review comment, Contributor]
Add docs for numGroups and groupIndex.

[Review comment, Contributor (PR author)]
@mengxr I already added more documentation. Is there something more I am missing here?

[Review comment, Contributor]
I don't see the docs of numGroups and groupIndex.

*/
private def findBestSplitsPerGroup(
input: RDD[LabeledPoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {

/*
* The high-level description for the best split optimizations are noted here.
@@ -296,20 +367,23 @@ object DecisionTree extends Serializable with Logging {
*/

// common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
val numNodes = scala.math.pow(2, level).toInt / numGroups
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size
logDebug("numFeatures = " + numFeatures)
val numBins = bins(0).length
logDebug("numBins = " + numBins)

// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex

/** Find the filters used before reaching the current code. */
def findParentFilters(nodeIndex: Int): List[Filter] = {
if (level == 0) {
List[Filter]()
} else {
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift
filters(nodeFilterIndex)
}
}
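
The group bookkeeping in `findBestSplits` and `findBestSplitsPerGroup` can be sketched as a standalone function that lists which node indices each group covers. `GroupIndexing` and `nodeIndicesForGroup` are hypothetical names, and the level and limit values are chosen only for illustration; the formulas for `numGroups`, `numNodes`, `groupShift`, and the `2^level - 1` offset are taken from the diff.

```scala
object GroupIndexing {
  // Global (level-order) indices of the nodes handled by one group at a level.
  def nodeIndicesForGroup(level: Int, maxLevelForSingleGroup: Int, groupIndex: Int): Range = {
    // Beyond the single-group limit, the number of groups doubles with each level.
    val numGroups =
      if (level > maxLevelForSingleGroup) math.pow(2, level - maxLevelForSingleGroup).toInt
      else 1
    val numNodes = math.pow(2, level).toInt / numGroups  // nodes handled per group
    val groupShift = numNodes * groupIndex               // offset of this group within the level
    val firstIndexAtLevel = math.pow(2, level).toInt - 1 // a zero-indexed level holds 2^level nodes
    (firstIndexAtLevel + groupShift) until (firstIndexAtLevel + groupShift + numNodes)
  }

  def main(args: Array[String]): Unit = {
    // Assumed values: level 3 with a single-group limit of level 1 gives 4 groups of 2 nodes.
    (0 until 4).foreach { g =>
      println(s"group $g -> ${nodeIndicesForGroup(3, 1, g).mkString(", ")}")
    }
  }
}
```

For level 3 with a single-group limit of 1, the four groups cover indices 7-8, 9-10, 11-12, and 13-14, so together they partition the eight nodes of that level. The same offset gives index 0 for the single node at level 0, consistent with a zero-indexed level holding 2^level nodes.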
@@ -878,7 +952,7 @@ object DecisionTree extends Serializable with Logging {
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1085,10 +1159,13 @@ object DecisionTree extends Serializable with Logging {

val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
val maxBins = options.getOrElse('maxBins, "100").toString.toInt
val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString)

val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage)
val model = DecisionTree.train(trainData, strategy)



// Load test data.
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)

Original file line number Diff line number Diff line change
@@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* @param maxMemory maximum memory in MB allocated to histogram aggregation. Default value is
[Review comment, Contributor]
maxMemory -> maxMemoryInMB?

[Review comment, Contributor (PR author)]
Sure.

* 128 MB.
[Review comment, Contributor]
Is the indentation correct?

*
*/
@Experimental
class Strategy (
@@ -43,4 +46,5 @@ class Strategy (
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemory: Int = 128) extends Serializable
Original file line number Diff line number Diff line change
@@ -254,7 +254,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -281,7 +281,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -333,7 +333,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -381,7 +381,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length === 100)

val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)