-
Notifications
You must be signed in to change notification settings - Fork 29k
MLI-1 Decision Trees #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cd53eae
92cedce
8bca1e2
0012a77
03f534c
dad0afc
4798aae
80e8c66
b0eb866
98ec8d5
02c595c
733d6dd
154aa77
b0e3e76
c8f6d60
e23c2e5
53108ed
6df35b9
dbb7ac1
d504eb1
6b7de78
b09dc98
c0e522b
f067d68
5841c28
0dd7659
dd0c0d7
9372779
84f85d6
d3023b3
63e786b
cd2c2b4
eb8fcbe
794ff4d
d1ef4f6
ad1fc21
62c2562
6068356
2116360
632818f
ff363a7
4576b64
24500c5
c487e6a
f963ef5
201702f
62dc723
e1dd86f
f536ae9
7d54b4f
1e8c704
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Manish Amde <[email protected]>
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) { | |
| val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) | ||
|
|
||
| //TODO: Level-wise training of tree and obtain Decision Tree model | ||
|
|
||
| val maxDepth = strategy.maxDepth | ||
|
|
||
| val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 | ||
|
|
@@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) { | |
|
|
||
| } | ||
|
|
||
| object DecisionTree extends Logging { | ||
| object DecisionTree extends Serializable { | ||
|
|
||
| /* | ||
| Returns an Array[Split] of optimal splits for all nodes at a given level | ||
|
|
||
| @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree | ||
| @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @param level Level of the tree | ||
| @param filters Filter for all nodes at a given level | ||
| @param splits possible splits for all features | ||
| @param bins possible bins for all features | ||
|
|
||
| @return Array[Split] instance for best splits for all nodes at a given level. | ||
| */ | ||
| def findBestSplits( | ||
| input : RDD[LabeledPoint], | ||
| strategy: Strategy, | ||
|
|
@@ -65,6 +76,16 @@ object DecisionTree extends Logging { | |
| splits : Array[Array[Split]], | ||
| bins : Array[Array[Bin]]) : Array[Split] = { | ||
|
|
||
| //TODO: Move these calculations outside | ||
| val numNodes = scala.math.pow(2, level).toInt | ||
| println("numNodes = " + numNodes) | ||
| //Find the number of features by looking at the first sample | ||
| val numFeatures = input.take(1)(0).features.length | ||
| println("numFeatures = " + numFeatures) | ||
| val numSplits = strategy.numSplits | ||
| println("numSplits = " + numSplits) | ||
|
|
||
| /*Find the filters used before reaching the current code*/ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use "/** ... */" |
||
| def findParentFilters(nodeIndex: Int): List[Filter] = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are several nested methods defined inside findBestSplits. Some of them are complex enough to have unit tests of their own. |
||
| if (level == 0) { | ||
| List[Filter]() | ||
|
|
@@ -75,6 +96,10 @@ object DecisionTree extends Logging { | |
| } | ||
| } | ||
|
|
||
| /*Find whether the sample is valid input for the current node. | ||
|
|
||
| In other words, does it pass through all the filters for the current node. | ||
| */ | ||
| def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { | ||
|
|
||
| for (filter <- parentFilters) { | ||
|
|
@@ -91,79 +116,130 @@ object DecisionTree extends Logging { | |
| true | ||
| } | ||
|
|
||
| /*Finds the right bin for the given feature*/ | ||
| def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { | ||
|
|
||
| //TODO: Do binary search | ||
| for (binIndex <- 0 until strategy.numSplits) { | ||
| val bin = bins(featureIndex)(binIndex) | ||
| //TODO: Remove this requirement post basic functional testing | ||
| require(bin.lowSplit.feature == featureIndex) | ||
| require(bin.highSplit.feature == featureIndex) | ||
| //TODO: Remove this requirement post basic functional | ||
| val lowThreshold = bin.lowSplit.threshold | ||
| val highThreshold = bin.highSplit.threshold | ||
| val features = labeledPoint.features | ||
| if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) { | ||
| if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) { | ||
| return binIndex | ||
| } | ||
| } | ||
| throw new UnknownError("no bin was found.") | ||
|
|
||
| } | ||
| def findBinsForLevel: Array[Double] = { | ||
|
|
||
| val numNodes = scala.math.pow(2, level).toInt | ||
| //Find the number of features by looking at the first sample | ||
| val numFeatures = input.take(1)(0).features.length | ||
| /*Finds bins for all nodes (and all features) at a given level | ||
| k features, l nodes | ||
| Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk | ||
| Denotes invalid sample for tree by noting bin for feature 1 as -1 | ||
| */ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need an extra space for indentation |
||
| def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = { | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra empty line |
||
| //TODO: Bit pack more by removing redundant label storage | ||
| // calculating bin index and label per feature per node | ||
| val arr = new Array[Double](2 * numFeatures * numNodes) | ||
| val arr = new Array[Double](1+(numFeatures * numNodes)) | ||
| arr(0) = labeledPoint.label | ||
| for (nodeIndex <- 0 until numNodes) { | ||
| val parentFilters = findParentFilters(nodeIndex) | ||
| //Find out whether the sample qualifies for the particular node | ||
| val sampleValid = isSampleValid(parentFilters, labeledPoint) | ||
| val shift = 2 * numFeatures * nodeIndex | ||
| if (sampleValid) { | ||
| val shift = 1 + numFeatures * nodeIndex | ||
| if (!sampleValid) { | ||
| //Add to invalid bin index -1 | ||
| for (featureIndex <- shift until (shift + numFeatures) by 2) { | ||
| arr(featureIndex + 1) = -1 | ||
| arr(featureIndex + 2) = labeledPoint.label | ||
| for (featureIndex <- 0 until numFeatures) { | ||
| arr(shift+featureIndex) = -1 | ||
| //TODO: Break since marking one bin is sufficient | ||
| } | ||
| } else { | ||
| for (featureIndex <- 0 until numFeatures) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace the for loop by a while loop |
||
| arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint) | ||
| arr(shift + (featureIndex * 2) + 2) = labeledPoint.label | ||
| //println("shift+featureIndex =" + (shift+featureIndex)) | ||
| arr(shift + featureIndex) = findBin(featureIndex, labeledPoint) | ||
| } | ||
| } | ||
|
|
||
| } | ||
| arr | ||
| } | ||
|
|
||
| val binMappedRDD = input.map(labeledPoint => findBinsForLevel) | ||
| /* | ||
| Performs a sequential aggreation over a partition | ||
|
|
||
| @param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification | ||
| and 3*numSplits*numFeatures*numNodes for regression | ||
| @param arr Array[Double] of size 1+(numFeatures*numNodes) | ||
| @return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification | ||
| and 3*numSplits*numFeatures*numNodes for regression | ||
| */ | ||
| def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { | ||
| for (node <- 0 until numNodes) { | ||
| val validSignalIndex = 1+numFeatures*node | ||
| val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false | ||
| if(isSampleValidForNode) { | ||
| for (feature <- 0 until numFeatures){ | ||
| val arrShift = 1 + numFeatures*node | ||
| val aggShift = numSplits*numFeatures*node | ||
| val arrIndex = arrShift + feature | ||
| val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt | ||
| agg(aggIndex) = agg(aggIndex) + 1 | ||
| } | ||
| } | ||
| } | ||
| agg | ||
| } | ||
|
|
||
| def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = { | ||
| par1 | ||
| } | ||
|
|
||
| println("input = " + input.count) | ||
| val binMappedRDD = input.map(x => findBinsForLevel(x)) | ||
| println("binMappedRDD.count = " + binMappedRDD.count) | ||
| //calculate bin aggregates | ||
|
|
||
| val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) | ||
|
|
||
| //find best split | ||
| println("binAggregates.length = " + binAggregates.length) | ||
|
|
||
|
|
||
| Array[Split]() | ||
| val bestSplits = new Array[Split](numNodes) | ||
| for (node <- 0 until numNodes){ | ||
| val binsForNode = binAggregates.slice(node,numSplits*node) | ||
| } | ||
|
|
||
| bestSplits | ||
| } | ||
|
|
||
| /* | ||
| Returns split and bins for decision tree calculation. | ||
|
|
||
| @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree | ||
| @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an | ||
| Array[Array[Bin]] of size (numFeatures,numSplits1) | ||
| */ | ||
| def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { | ||
|
|
||
| val numSplits = strategy.numSplits | ||
| logDebug("numSplits = " + numSplits) | ||
| println("numSplits = " + numSplits) | ||
|
|
||
| //Calculate the number of sample for approximate quantile calculation | ||
| //TODO: Justify this calculation | ||
| val requiredSamples = numSplits*numSplits | ||
| val count = input.count() | ||
| val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 | ||
| logDebug("fraction of data used for calculating quantiles = " + fraction) | ||
| println("fraction of data used for calculating quantiles = " + fraction) | ||
|
|
||
| //sampled input for RDD calculation | ||
| val sampledInput = input.sample(false, fraction, 42).collect() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why fixing the seed here? It means certain item won't get selected. |
||
| val numSamples = sampledInput.length | ||
|
|
||
| //TODO: Remove this requirement | ||
| require(numSamples > numSplits, "length of input samples should be greater than numSplits") | ||
|
|
||
| //Find the number of features by looking at the first sample | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| */ | ||
| package org.apache.spark.mllib.tree.impurity | ||
|
|
||
| trait Impurity { | ||
| trait Impurity extends Serializable { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The For a generic interface, an additional The |
||
|
|
||
| def calculate(c0 : Double, c1 : Double): Double | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. JavaDoc for public methods. |
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove extra blank line.