Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
  • Loading branch information
jkbradley committed Aug 17, 2014
commit a0ed0daa4c3622e19626de7aa3b29e07c6015ff2
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl.{DTMetadata, TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("init")

val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata = DTMetadata.buildMetadata(retaggedInput, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
logDebug("algo = " + strategy.algo)

// Find the splits and the corresponding bins (interval between the splits) using a sample
Expand Down Expand Up @@ -443,7 +443,7 @@ object DecisionTree extends Serializable with Logging {
protected[tree] def findBestSplits(
input: RDD[TreePoint],
parentImpurities: Array[Double],
metadata: DTMetadata,
metadata: DecisionTreeMetadata,
level: Int,
nodes: Array[Node],
splits: Array[Array[Split]],
Expand Down Expand Up @@ -489,7 +489,7 @@ object DecisionTree extends Serializable with Logging {
private def findBestSplitsPerGroup(
input: RDD[TreePoint],
parentImpurities: Array[Double],
metadata: DTMetadata,
metadata: DecisionTreeMetadata,
level: Int,
nodes: Array[Node],
splits: Array[Array[Split]],
Expand Down Expand Up @@ -551,7 +551,9 @@ object DecisionTree extends Serializable with Logging {

/**
* Get the node index corresponding to this data point.
* This is used during training, mimicking prediction.
* This function mimics prediction, passing an example from the root node down to a node
* at the current level being trained; that node's index is returned.
*
* @return Leaf index if the data point reaches a leaf.
* Otherwise, last node reachable in tree matching this example.
*/
Expand Down Expand Up @@ -608,7 +610,8 @@ object DecisionTree extends Serializable with Logging {
val levelOffset = (1 << level) - 1

/**
* Find the node (indexed from 0 at the start of this level) for the given example.
* Find the node index for the given example.
* Nodes are indexed from 0 at the start of this (level, group).
* If the example does not reach this level, returns a value < 0.
*/
def treePointToNodeIndex(treePoint: TreePoint): Int = {
Expand Down Expand Up @@ -1261,7 +1264,7 @@ object DecisionTree extends Serializable with Logging {
*
* @param numBins Number of bins = 1 + number of possible splits.
*/
private def getElementsPerNode(metadata: DTMetadata, numBins: Int): Int = {
private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
if (metadata.isClassification) {
if (metadata.isMulticlassWithCategoricalFeatures) {
2 * metadata.numClasses * numBins * metadata.numFeatures
Expand Down Expand Up @@ -1304,7 +1307,7 @@ object DecisionTree extends Serializable with Logging {
*/
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
metadata: DTMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {

val count = input.count()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.rdd.RDD
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
*/
private[tree] class DTMetadata(
private[tree] class DecisionTreeMetadata(
val numFeatures: Int,
val numExamples: Long,
val numClasses: Int,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add documentation for numClasses explaining what value it takes for regression problems.

Expand All @@ -59,9 +59,9 @@ private[tree] class DTMetadata(

}

private[tree] object DTMetadata {
private[tree] object DecisionTreeMetadata {

def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DTMetadata = {
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

val numFeatures = input.take(1)(0).features.size
val numExamples = input.count()
Expand Down Expand Up @@ -93,7 +93,7 @@ private[tree] object DTMetadata {
}
}

new DTMetadata(numFeatures, numExamples, numClasses, maxBins,
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
strategy.impurity, strategy.quantileCalculationStrategy)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ private[tree] object TreePoint {
def convertToTreeRDD(
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DTMetadata): RDD[TreePoint] = {
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
input.map { x =>
TreePoint.labeledPointToTreePoint(x, bins, metadata)
}
Expand All @@ -67,7 +67,7 @@ private[tree] object TreePoint {
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
metadata: DTMetadata): TreePoint = {
metadata: DecisionTreeMetadata): TreePoint = {

val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DTMetadata, TreePoint}
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.linalg.Vectors
Expand Down Expand Up @@ -64,7 +64,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
Expand All @@ -83,7 +83,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
Expand Down Expand Up @@ -164,7 +164,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

// Check splits.
Expand Down Expand Up @@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

// Expecting 2^2 - 1 = 3 bins/splits
Expand Down Expand Up @@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

// 2^10 - 1 > 100, so categorical variables will be ordered
Expand Down Expand Up @@ -433,7 +433,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
Expand Down Expand Up @@ -462,7 +462,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
Expand Down Expand Up @@ -502,7 +502,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -526,7 +526,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -551,7 +551,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -576,7 +576,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand All @@ -601,7 +601,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
Expand Down Expand Up @@ -653,7 +653,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
Expand Down Expand Up @@ -710,7 +710,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
Expand Down Expand Up @@ -739,7 +739,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3)
assert(strategy.isMulticlassClassification)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
Expand All @@ -765,7 +765,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
Expand All @@ -790,7 +790,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val metadata = DTMetadata.buildMetadata(rdd, strategy)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
Expand Down