Closed
Changes from 1 commit (27 commits in this pull request)
a95bc22 - timing for DecisionTree internals (jkbradley, Aug 5, 2014)
511ec85 - Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 6, 2014)
bcf874a - Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 7, 2014)
f61e9d2 - Merge remote-tracking branch 'upstream/master' into dt-timing (jkbradley, Aug 8, 2014)
3211f02 - Optimizing DecisionTree (jkbradley, Aug 8, 2014)
0f676e2 - Optimizations + Bug fix for DecisionTree (jkbradley, Aug 8, 2014)
b2ed1f3 - Merge remote-tracking branch 'upstream/master' into dt-opt (jkbradley, Aug 8, 2014)
b914f3b - DecisionTree optimization: eliminated filters + small changes (jkbradley, Aug 9, 2014)
c1565a5 - Small DecisionTree updates: (jkbradley, Aug 11, 2014)
a87e08f - Merge remote-tracking branch 'upstream/master' into dt-opt1 (jkbradley, Aug 14, 2014)
8464a6e - Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. … (jkbradley, Aug 14, 2014)
e66f1b1 - TreePoint (jkbradley, Aug 14, 2014)
d036089 - Print timing info to logDebug. (jkbradley, Aug 14, 2014)
430d782 - Added more debug info on binning error. Added some docs. (jkbradley, Aug 14, 2014)
356daba - Merge branch 'dt-opt1' into dt-opt2 (jkbradley, Aug 14, 2014)
26d10dd - Removed tree/model/Filter.scala since no longer used. Removed debugg… (jkbradley, Aug 15, 2014)
2d2aaaf - Merge remote-tracking branch 'upstream/master' into dt-opt1 (jkbradley, Aug 15, 2014)
6b5651e - Updates based on code review. 1 major change: persisting to memory +… (jkbradley, Aug 15, 2014)
5f2dec2 - Fixed scalastyle issue in TreePoint (jkbradley, Aug 15, 2014)
f40381c - Merge branch 'dt-opt1' into dt-opt2 (jkbradley, Aug 15, 2014)
797f68a - Fixed DecisionTreeSuite bug for training second level. Needed to upd… (jkbradley, Aug 15, 2014)
931a3a7 - Merge remote-tracking branch 'upstream/master' into dt-opt2 (jkbradley, Aug 15, 2014)
6a38f48 - Added DTMetadata class for cleaner code (jkbradley, Aug 16, 2014)
db0d773 - scala style fix (jkbradley, Aug 16, 2014)
ac0b9f8 - Small updates based on code review. (jkbradley, Aug 16, 2014)
3726d20 - Small code improvements based on code review. (jkbradley, Aug 17, 2014)
a0ed0da - Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. (jkbradley, Aug 17, 2014)
Updates based on code review. 1 major change: persisting to memory + disk, not just memory.

Details:

DecisionTree
* Changed: .cache() -> .persist(StorageLevel.MEMORY_AND_DISK)
** This gave major performance improvements on small tests. E.g., 500K examples, 500 features, depth 5, on a MacBook, took 292 sec with cache() and 112 sec when also using disk (see the sketch after these notes).
* Changed for loops to while loops
* Small cleanups

TimeTracker
* Removed useless timing in DecisionTree

TreePoint
* Renamed features to binnedFeatures
jkbradley committed Aug 15, 2014
commit 6b5651e7671315f78aef42344ab514e3cf8052df
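Before the per-file diffs, here is a minimal, self-contained sketch of the headline change; the helper name and wrapper are illustrative, not code from this commit:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// cache() is shorthand for persist(StorageLevel.MEMORY_ONLY): partitions that
// do not fit in memory are dropped and then recomputed on every level-wise
// pass over the training data.
// MEMORY_AND_DISK spills those partitions to local disk instead, so later
// passes re-read them rather than recompute them -- the source of the
// 292 sec -> 112 sec improvement reported in the commit message.
def persistForTraining[T](treeInput: RDD[T]): RDD[T] =
  treeInput.persist(StorageLevel.MEMORY_AND_DISK)
```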
49 changes: 24 additions & 25 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,7 +17,6 @@

package org.apache.spark.mllib.tree


import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
@@ -32,6 +31,7 @@ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom


@@ -59,11 +59,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.start("total")

-// Cache input RDD for speedup during multiple passes.
timer.start("init")

val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)
timer.stop("init")

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -73,9 +72,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.stop("findSplitsBins")
logDebug("numBins = " + numBins)

timer.start("init")
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
timer.stop("init")
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
.persist(StorageLevel.MEMORY_AND_DISK)

// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -90,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
-val numFeatures = treeInput.take(1)(0).features.size
+val numFeatures = treeInput.take(1)(0).binnedFeatures.size

// Calculate level for single group construction

@@ -110,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
logDebug("max level for single group = " + maxLevelForSingleGroup)

timer.stop("init")

/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -126,7 +127,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("level = " + level)
logDebug("#####################################")


// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
@@ -167,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.stop("total")

logDebug("Internal timing for DecisionTree:")
logDebug(s"$timer")
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

new DecisionTreeModel(topNode, strategy.algo)
}
@@ -226,7 +226,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}


object DecisionTree extends Serializable with Logging {

/**
@@ -536,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("numNodes = " + numNodes)

// Find the number of features by looking at the first sample.
[Review comment from a Contributor]
Should get numFeatures from metadata. Calling first() triggers at least one job.

[Reply from the Member Author]
Thanks! I forgot to port that change from my next PR.
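A rough sketch of the reviewer's suggestion; the type and field names below are hypothetical stand-ins, since the DTMetadata/DecisionTreeMetadata class is only added in later commits of this PR:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical stand-ins for the real classes in this PR.
case class TreePointStub(label: Double, binnedFeatures: Array[Int])
case class TreeMetadataStub(numFeatures: Int)

// Current code: inspecting the data triggers at least one Spark job.
def numFeaturesViaJob(input: RDD[TreePointStub]): Int =
  input.first().binnedFeatures.length

// Suggested: record numFeatures once while building metadata on the driver,
// then read it with no job at all.
def numFeaturesViaMetadata(metadata: TreeMetadataStub): Int =
  metadata.numFeatures
```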

-val numFeatures = input.first().features.size
+val numFeatures = input.first().binnedFeatures.size
logDebug("numFeatures = " + numFeatures)

// numBins: Number of bins = 1 + number of possible splits
@@ -578,12 +577,12 @@
}

// Apply each filter and check sample validity. Return false when invalid condition found.
-for (filter <- parentFilters) {
+parentFilters.foreach { filter =>
val featureIndex = filter.split.feature
val comparison = filter.comparison
val isFeatureContinuous = filter.split.featureType == Continuous
if (isFeatureContinuous) {
-val binId = treePoint.features(featureIndex)
+val binId = treePoint.binnedFeatures(featureIndex)
val bin = bins(featureIndex)(binId)
val featureValue = bin.highSplit.threshold
val threshold = filter.split.threshold
@@ -598,9 +597,9 @@
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
val featureValue = if (isUnorderedFeature) {
-treePoint.features(featureIndex)
+treePoint.binnedFeatures(featureIndex)
} else {
-val binId = treePoint.features(featureIndex)
+val binId = treePoint.binnedFeatures(featureIndex)
bins(featureIndex)(binId).category
}
val containsFeature = filter.split.categories.contains(featureValue)
@@ -648,9 +647,8 @@ object DecisionTree extends Serializable with Logging {
arr(shift) = InvalidBinIndex
} else {
var featureIndex = 0
-// TODO: Vectorize this
while (featureIndex < numFeatures) {
-arr(shift + featureIndex) = treePoint.features(featureIndex)
+arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
featureIndex += 1
}
}
@@ -660,9 +658,8 @@
}

// Find feature bins for all nodes at a level.
timer.start("findBinsForLevel")
timer.start("aggregation")
val binMappedRDD = input.map(x => findBinsForLevel(x))
timer.stop("findBinsForLevel")

/**
* Increment aggregate in location for (node, feature, bin, label).
@@ -907,13 +904,11 @@ object DecisionTree extends Serializable with Logging {
combinedAggregate
}


// Calculate bin aggregates.
timer.start("binAggregates")
val binAggregates = {
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
}
timer.stop("binAggregates")
timer.stop("aggregation")
logDebug("binAggregates.length = " + binAggregates.length)

/**
@@ -1225,12 +1220,16 @@
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

-for (featureIndex <- 0 until numFeatures) {
+var featureIndex = 0
+while (featureIndex < numFeatures) {
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
-for (splitIndex <- 0 until numSplitsForFeature) {
+var splitIndex = 0
+while (splitIndex < numSplitsForFeature) {
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
splitIndex, rightNodeAgg, nodeImpurity)
+splitIndex += 1
}
+featureIndex += 1
}
gains
}
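A side note on the for-to-while rewrites in this file: in Scala 2.x a `for (i <- 0 until n)` loop desugars to `Range.foreach` with a closure, which tends to be slower than a bare while loop on hot paths. A minimal illustration (not PR code):

```scala
// Desugars to (0 until xs.length).foreach(i => ...): a closure invocation
// per iteration.
def sumFor(xs: Array[Double]): Double = {
  var s = 0.0
  for (i <- 0 until xs.length) {
    s += xs(i)
  }
  s
}

// Plain while loop: no closure allocation or dispatch, which is why the
// gains-calculation loops above were rewritten this way.
def sumWhile(xs: Array[Double]): Double = {
  var s = 0.0
  var i = 0
  while (i < xs.length) {
    s += xs(i)
    i += 1
  }
  s
}
```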
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -25,8 +25,7 @@ import org.apache.spark.annotation.Experimental
* Time tracker implementation which holds labeled timers.
*/
@Experimental
-private[tree]
-class TimeTracker extends Serializable {
+private[tree] class TimeTracker extends Serializable {

private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()

@@ -36,24 +35,24 @@ class TimeTracker extends Serializable {
* Starts a new timer, or re-starts a stopped timer.
*/
def start(timerLabel: String): Unit = {
-val tmpTime = System.nanoTime()
+val currentTime = System.nanoTime()
if (starts.contains(timerLabel)) {
throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
s" timerLabel = $timerLabel before that timer was stopped.")
}
-starts(timerLabel) = tmpTime
+starts(timerLabel) = currentTime
}

/**
* Stops a timer and returns the elapsed time in seconds.
*/
def stop(timerLabel: String): Double = {
-val tmpTime = System.nanoTime()
+val currentTime = System.nanoTime()
if (!starts.contains(timerLabel)) {
throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
s" timerLabel = $timerLabel, but that timer was not started.")
}
-val elapsed = tmpTime - starts(timerLabel)
+val elapsed = currentTime - starts(timerLabel)
starts.remove(timerLabel)
if (totals.contains(timerLabel)) {
totals(timerLabel) += elapsed
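A usage sketch for the start/stop API shown above (illustrative only; TimeTracker is private[tree], and the PR logs the totals via logInfo rather than println):

```scala
val timer = new TimeTracker()
timer.start("total")

timer.start("init")
// ... initialization work ...
val initSeconds = timer.stop("init")  // elapsed seconds for this pair

timer.start("init")                   // legal again: the label was stopped
// ... more initialization ...
timer.stop("init")                    // accumulates into the "init" total

timer.stop("total")
println(s"$timer")                    // prints the accumulated per-label totals
```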
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD
* or any categorical feature used in regression or binary classification.
*
* @param label Label from LabeledPoint
-* @param features Binned feature values.
-*                 Same length as LabeledPoint.features, but values are bin indices.
+* @param binnedFeatures Binned feature values.
+*                       Same length as LabeledPoint.features, but values are bin indices.
*/
-private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable {
+private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable {
}


private[tree] object TreePoint {

/**
@@ -76,7 +75,7 @@ private[tree] object TreePoint {
val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
-var featureIndex = 0 // offset by 1 for label
+var featureIndex = 0
while (featureIndex < numFeatures) {
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
val isFeatureContinuous = featureInfo.isEmpty
@@ -98,7 +97,6 @@ private[tree] object TreePoint {
new TreePoint(labeledPoint.label, arr)
}


/**
* Find bin for one (labeledPoint, feature).
*
@@ -129,11 +127,9 @@
val highThreshold = bin.highSplit.threshold
if ((lowThreshold < feature) && (highThreshold >= feature)) {
return mid
-}
-else if (lowThreshold >= feature) {
+} else if (lowThreshold >= feature) {
right = mid - 1
-}
-else {
+} else {
left = mid + 1
}
}
@@ -181,7 +177,8 @@ private[tree] object TreePoint {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
if (binIndex == -1) {
throw new UnknownError("No bin was found for continuous feature." +
throw new RuntimeException("No bin was found for continuous feature." +
" This error can occur when given invalid data values (such as NaN)." +
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
}
binIndex
@@ -193,7 +190,8 @@
sequentialBinSearchForOrderedCategoricalFeature()
}
if (binIndex == -1) {
throw new UnknownError("No bin was found for categorical feature." +
throw new RuntimeException("No bin was found for categorical feature." +
" This error can occur when given invalid data values (such as NaN)." +
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
}
binIndex
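To summarize the binning step this file implements: each continuous feature value is located by binary search over that feature's bin thresholds, and a -1 result (e.g., for NaN input) becomes the RuntimeException added above. A simplified sketch of the search (categorical handling omitted; names are illustrative):

```scala
// Bins for one continuous feature, given as parallel arrays of low/high
// split thresholds. A value x belongs to bin i when
// lows(i) < x && highs(i) >= x.
def findBinForContinuousFeature(
    lows: Array[Double],
    highs: Array[Double],
    feature: Double): Int = {
  var left = 0
  var right = lows.length - 1
  while (left <= right) {
    val mid = left + (right - left) / 2
    if (lows(mid) < feature && highs(mid) >= feature) {
      return mid
    } else if (lows(mid) >= feature) {
      right = mid - 1
    } else {
      left = mid + 1
    }
  }
  -1 // not found (e.g., NaN); the caller raises the RuntimeException
}
```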
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,17 +17,16 @@

package org.apache.spark.mllib.tree

-import org.apache.spark.mllib.tree.impl.TreePoint

import scala.collection.JavaConverters._

import org.scalatest.FunSuite

-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
+import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@@ -43,10 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
prediction != expected.label
}
val accuracy = (input.length - numOffPredictions).toDouble / input.length
-if (accuracy < requiredAccuracy) {
-  println(s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
-}
-assert(accuracy >= requiredAccuracy)
+assert(accuracy >= requiredAccuracy,
+  s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
}

def validateRegressor(
@@ -59,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
err * err
}.sum
val mse = squaredError / input.length
-assert(mse <= requiredMSE)
+assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}

test("split and bin calculation") {