Closed
Changes from 1 commit (of 51 commits)
cd53eae
skeletal framework
manishamde Nov 28, 2013
92cedce
basic building blocks for intermediate RDD calculation. untested.
manishamde Dec 2, 2013
8bca1e2
additional code for creating intermediate RDD
manishamde Dec 9, 2013
0012a77
basic stump working
manishamde Dec 10, 2013
03f534c
some more tests
manishamde Dec 10, 2013
dad0afc
decision stump functionality working
manishamde Dec 15, 2013
4798aae
added gain stats class
manishamde Dec 15, 2013
80e8c66
working version of multi-level split calculation
manishamde Dec 16, 2013
b0eb866
added logic to handle leaf nodes
manishamde Dec 16, 2013
98ec8d5
tree building and prediction logic
manishamde Dec 22, 2013
02c595c
added command line parsing
manishamde Dec 22, 2013
733d6dd
fixed tests
manishamde Dec 22, 2013
154aa77
enums for configurations
manishamde Dec 23, 2013
b0e3e76
adding enum for feature type
manishamde Jan 12, 2014
c8f6d60
adding enum for feature type
manishamde Jan 12, 2014
e23c2e5
added regression support
manishamde Jan 19, 2014
53108ed
fixing index for highest bin
manishamde Jan 20, 2014
6df35b9
regression predict logic
manishamde Jan 21, 2014
dbb7ac1
categorical feature support
manishamde Jan 23, 2014
d504eb1
more tests for categorical features
manishamde Jan 23, 2014
6b7de78
minor refactoring and tests
manishamde Jan 26, 2014
b09dc98
minor refactoring
manishamde Jan 26, 2014
c0e522b
updated predict and split threshold logic
manishamde Jan 27, 2014
f067d68
minor cleanup
manishamde Jan 27, 2014
5841c28
unit tests for categorical features
manishamde Jan 27, 2014
0dd7659
basic doc
manishamde Jan 27, 2014
dd0c0d7
minor: some docs
manishamde Jan 27, 2014
9372779
code style: max line length <= 100
manishamde Feb 17, 2014
84f85d6
code documentation
manishamde Feb 28, 2014
d3023b3
adding more docs for nested methods
manishamde Mar 6, 2014
63e786b
added multiple train methods for java compatibility
manishamde Mar 6, 2014
cd2c2b4
fixing code style based on feedback
manishamde Mar 7, 2014
eb8fcbe
minor code style updates
manishamde Mar 7, 2014
794ff4d
minor improvements to docs and style
manishamde Mar 10, 2014
d1ef4f6
more documentation
manishamde Mar 10, 2014
ad1fc21
incorporated mengxr's code style suggestions
manishamde Mar 11, 2014
62c2562
fixing comment indentation
manishamde Mar 11, 2014
6068356
ensuring num bins is always greater than max number of categories
manishamde Mar 12, 2014
2116360
removing dummy bin calculation for categorical variables
manishamde Mar 12, 2014
632818f
removing threshold for classification predict method
manishamde Mar 13, 2014
ff363a7
binary search for bins and while loop for categorical feature bins
manishamde Mar 17, 2014
4576b64
documentation and for-to-while loop conversion
manishamde Mar 23, 2014
24500c5
minor style updates
mengxr Mar 23, 2014
c487e6a
Merge pull request #1 from mengxr/dtree
manishamde Mar 23, 2014
f963ef5
making methods private
manishamde Mar 23, 2014
201702f
making some more methods private
manishamde Mar 23, 2014
62dc723
updating javadoc and converting helper methods to package private to …
manishamde Mar 24, 2014
e1dd86f
implementing code style suggestions
manishamde Mar 25, 2014
f536ae9
another pass on code style
mengxr Mar 31, 2014
7d54b4f
Merge pull request #4 from mengxr/dtree
manishamde Mar 31, 2014
1e8c704
remove numBins field in the Strategy class
manishamde Apr 1, 2014
additional code for creating intermediate RDD
Signed-off-by: Manish Amde <[email protected]>
manishamde committed Feb 28, 2014
commit 8bca1e20b703fd90bc6fcdbed5d36b42a0bdf66e
124 changes: 100 additions & 24 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

//TODO: Level-wise training of tree and obtain Decision Tree model

val maxDepth = strategy.maxDepth

val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
@@ -55,8 +54,20 @@

Contributor comment: Remove extra blank line.

}

object DecisionTree extends Logging {
object DecisionTree extends Serializable {

/*
Returns an Array[Split] of optimal splits for all nodes at a given level

@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
@param level Level of the tree
@param filters Filter for all nodes at a given level
@param splits possible splits for all features
@param bins possible bins for all features

@return Array[Split] instance for best splits for all nodes at a given level.
*/
def findBestSplits(
input : RDD[LabeledPoint],
strategy: Strategy,
@@ -65,6 +76,16 @@
splits : Array[Array[Split]],
bins : Array[Array[Bin]]) : Array[Split] = {

//TODO: Move these calculations outside
val numNodes = scala.math.pow(2, level).toInt
println("numNodes = " + numNodes)
//Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length
println("numFeatures = " + numFeatures)
val numSplits = strategy.numSplits
println("numSplits = " + numSplits)

/*Find the filters used before reaching the current code*/
Contributor comment: use "/** ... */"

def findParentFilters(nodeIndex: Int): List[Filter] = {
Contributor comment: There are several nested methods defined inside findBestSplits. Some of them are complex enough to have unit tests of their own.

if (level == 0) {
List[Filter]()
@@ -75,6 +96,10 @@
}
}

/*Find whether the sample is valid input for the current node.

In other words, does it pass through all the filters for the current node.
*/
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {

for (filter <- parentFilters) {
@@ -91,79 +116,130 @@
true
}

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {

//TODO: Do binary search
for (binIndex <- 0 until strategy.numSplits) {
val bin = bins(featureIndex)(binIndex)
//TODO: Remove this requirement post basic functional testing
require(bin.lowSplit.feature == featureIndex)
require(bin.highSplit.feature == featureIndex)
//TODO: Remove this requirement post basic functional
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
val features = labeledPoint.features
if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) {
if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
return binIndex
}
}
throw new UnknownError("no bin was found.")

}
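The TODO above calls for a binary search; a minimal sketch of that lookup, assuming bins(featureIndex) is ordered by threshold (as find_splits_bins produces) and with a hypothetical helper name:

def findBinBinarySearch(value: Double, featureBins: Array[Bin]): Int = {
  //binary search for the bin whose [lowSplit, highSplit) range contains value
  var lo = 0
  var hi = featureBins.length - 1
  while (lo <= hi) {
    val mid = lo + (hi - lo) / 2
    val bin = featureBins(mid)
    if (value < bin.lowSplit.threshold) {
      hi = mid - 1
    } else if (value >= bin.highSplit.threshold) {
      lo = mid + 1
    } else {
      return mid //lowSplit.threshold <= value < highSplit.threshold
    }
  }
  -1 //no matching bin; the caller can treat this as an error
}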
def findBinsForLevel: Array[Double] = {

val numNodes = scala.math.pow(2, level).toInt
//Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length
/*Finds bins for all nodes (and all features) at a given level.
For l nodes and k features, the storage layout is:
label, b11, b12, .., b1k, b21, b22, .., b2k, ..., bl1, bl2, .., blk
An invalid sample for the tree is denoted by marking the bin for feature 1 as -1.
*/
Contributor comment: need an extra space for indentation

def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = {


Contributor comment: extra empty line

//TODO: Bit pack more by removing redundant label storage
// calculating bin index and label per feature per node
val arr = new Array[Double](2 * numFeatures * numNodes)
val arr = new Array[Double](1+(numFeatures * numNodes))
arr(0) = labeledPoint.label
for (nodeIndex <- 0 until numNodes) {
val parentFilters = findParentFilters(nodeIndex)
//Find out whether the sample qualifies for the particular node
val sampleValid = isSampleValid(parentFilters, labeledPoint)
val shift = 2 * numFeatures * nodeIndex
if (sampleValid) {
val shift = 1 + numFeatures * nodeIndex
if (!sampleValid) {
//Add to invalid bin index -1
for (featureIndex <- shift until (shift + numFeatures) by 2) {
arr(featureIndex + 1) = -1
arr(featureIndex + 2) = labeledPoint.label
for (featureIndex <- 0 until numFeatures) {
arr(shift+featureIndex) = -1
//TODO: Break since marking one bin is sufficient
}
} else {
for (featureIndex <- 0 until numFeatures) {
Contributor comment: replace the for loop by a while loop

arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint)
arr(shift + (featureIndex * 2) + 2) = labeledPoint.label
//println("shift+featureIndex =" + (shift+featureIndex))
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
}
}

}
arr
}
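For reference, the layout documented above maps each (node, feature) pair to a single slot of the flat array; a small sketch of the indexing, with a hypothetical helper name:

//arr(0) holds the label; the bin for (node, feature) lives at
//1 + numFeatures * node + feature
def binSlot(node: Int, feature: Int, numFeatures: Int): Int =
  1 + numFeatures * node + feature

//e.g. with 3 features, node 2 / feature 1 lands at index 1 + 3 * 2 + 1 = 8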

val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
/*
Performs a sequential aggregation over a partition

@param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
@param arr Array[Double] of size 1+(numFeatures*numNodes)
@return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
*/
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
for (node <- 0 until numNodes) {
val validSignalIndex = 1+numFeatures*node
val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false
if(isSampleValidForNode) {
for (feature <- 0 until numFeatures){
val arrShift = 1 + numFeatures*node
val aggShift = numSplits*numFeatures*node
val arrIndex = arrShift + feature
val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt
agg(aggIndex) = agg(aggIndex) + 1
}
}
}
agg
}

def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = {
par1
}
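Note that binCombOp above returns par1 unchanged and discards par2, which reads as a placeholder at this commit: RDD.aggregate needs a combiner that merges both partial results. A sketch of an element-wise sum (an assumption, not the code in this commit):

def binCombOpSummed(par1: Array[Double], par2: Array[Double]): Array[Double] = {
  //element-wise sum of two partial bin-count aggregates
  var i = 0
  while (i < par1.length) {
    par1(i) += par2(i)
    i += 1
  }
  par1
}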

println("input = " + input.count)
val binMappedRDD = input.map(x => findBinsForLevel(x))
println("binMappedRDD.count = " + binMappedRDD.count)
//calculate bin aggregates

val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)

//find best split
println("binAggregates.length = " + binAggregates.length)


Array[Split]()
val bestSplits = new Array[Split](numNodes)
for (node <- 0 until numNodes){
val binsForNode = binAggregates.slice(node,numSplits*node)
}

bestSplits
}
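A side note on the loop above: Array.slice(from, until) called as slice(node, numSplits*node) yields an empty array for node = 0, so the per-node slicing also looks like work in progress. Assuming each node's block of binAggregates is numSplits*numFeatures entries wide, as binSeqOp writes it, the intended slice might be:

//assumption: one block of numSplits * numFeatures counts per node
val binsForNode = binAggregates.slice(
  numSplits * numFeatures * node,
  numSplits * numFeatures * (node + 1))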

/*
Returns splits and bins for decision tree calculation.

@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
@return a tuple of (splits,bins) where splits is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
Array[Array[Bin]] of size (numFeatures,numSplits)
*/
def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {

val numSplits = strategy.numSplits
logDebug("numSplits = " + numSplits)
println("numSplits = " + numSplits)

//Calculate the number of samples for approximate quantile calculation
//TODO: Justify this calculation
val requiredSamples = numSplits*numSplits
val count = input.count()
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
logDebug("fraction of data used for calculating quantiles = " + fraction)
println("fraction of data used for calculating quantiles = " + fraction)
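As a quick worked instance of this calculation: with numSplits = 100, requiredSamples = 100 * 100 = 10,000, so an input of 1,000,000 points is sampled at fraction 0.01, while any input smaller than 10,000 points is used in full.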

//sampled input for RDD calculation
val sampledInput = input.sample(false, fraction, 42).collect()
Contributor comment: Why fix the seed here? It means certain items won't get selected.

val numSamples = sampledInput.length

//TODO: Remove this requirement
require(numSamples > numSplits, "length of input samples should be greater than numSplits")

//Find the number of features by looking at the first sample
mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
@@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree

import org.apache.spark.mllib.tree.impurity.Impurity

class Strategy (
case class Strategy (
val kind : String,
val impurity : Impurity,
val maxDepth : Int,
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.mllib.tree.impurity

trait Impurity {
trait Impurity extends Serializable {

Contributor comment: Impurity should be renamed to Error or something more technical and familiar. Also see the comments earlier for the necessity and example design of a generic Error interface. The calculate method can be renamed to something verbose like error. For a generic interface, an additional ErrorStats trait and an error(errorStats: ErrorStats) method can be added. For example, Variance or, more aptly, SquareError, would implement case class SquareErrorStats(count: Long, mean: Double, meanSquare: Double) and error(errorStats) = errorStats.meanSquare - errorStats.mean * errorStats.mean / count. Note that ErrorStats should have aggregation methods; e.g., it's easy to see the implementation for SquareErrorStats. The Variance class should be renamed to SquareError, Entropy to EntropyError or KLDivergence, and Gini to GiniError.
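One possible reading of that proposal, sketched below; the names Error, SquareError, SquareErrorStats, and error follow the comment, while the generic type parameter and the sum-based fields (count, sum, sumSquares instead of count, mean, meanSquare) are assumptions chosen so that aggregation is a plain component-wise sum:

trait Error[S] {
  def statsFor(label: Double): S //stats from a single data point
  def merge(a: S, b: S): S       //aggregate two partial stats
  def error(stats: S): Double    //final error from aggregated stats
}

case class SquareErrorStats(count: Long, sum: Double, sumSquares: Double)

object SquareError extends Error[SquareErrorStats] {
  def statsFor(label: Double) = SquareErrorStats(1L, label, label * label)
  def merge(a: SquareErrorStats, b: SquareErrorStats) =
    SquareErrorStats(a.count + b.count, a.sum + b.sum, a.sumSquares + b.sumSquares)
  //sum of squared deviations from the mean: sum(x^2) - (sum x)^2 / n
  def error(s: SquareErrorStats): Double =
    if (s.count == 0L) 0.0 else s.sumSquares - s.sum * s.sum / s.count
}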


def calculate(c0 : Double, c1 : Double): Double
Contributor comment: JavaDoc for public methods.


mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,6 +28,7 @@ import org.jblas._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.model.Filter

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

@@ -54,6 +55,23 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins(0).length==100)
println(splits(1)(98))
}

test("stump"){
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
assert(bins.length==2)
assert(bins(0).length==100)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins)
}

}

object DecisionTreeSuite {