-
Notifications
You must be signed in to change notification settings - Fork 29k
MLI-1 Decision Trees #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cd53eae
92cedce
8bca1e2
0012a77
03f534c
dad0afc
4798aae
80e8c66
b0eb866
98ec8d5
02c595c
733d6dd
154aa77
b0e3e76
c8f6d60
e23c2e5
53108ed
6df35b9
dbb7ac1
d504eb1
6b7de78
b09dc98
c0e522b
f067d68
5841c28
0dd7659
dd0c0d7
9372779
84f85d6
d3023b3
63e786b
cd2c2b4
eb8fcbe
794ff4d
d1ef4f6
ad1fc21
62c2562
6068356
2116360
632818f
ff363a7
4576b64
24500c5
c487e6a
f963ef5
201702f
62dc723
e1dd86f
f536ae9
7d54b4f
1e8c704
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log | |
|
|
||
| //Cache input RDD for speedup during multiple passes | ||
| input.cache() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the current implementation of other algorithms in MLlib, we let users choose whether the data should be cached or not. How many passes does your algorithm need? |
||
| logDebug("algo = " + strategy.algo) | ||
|
|
||
| //Finding the splits and the corresponding bins (interval between the splits) using a sample | ||
| // of the input data. | ||
| val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) | ||
| logDebug("numSplits = " + bins(0).length) | ||
|
|
||
| //Noting numBins for the input data | ||
| strategy.numBins = bins(0).length | ||
|
|
||
| //The depth of the decision tree | ||
| val maxDepth = strategy.maxDepth | ||
|
|
||
| //The max number of nodes possible given the depth of the tree | ||
| val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 | ||
| //Initializing an array to hold filters applied to points for each node | ||
| val filters = new Array[List[Filter]](maxNumNodes) | ||
| //The filter at the top node is an empty list | ||
| filters(0) = List() | ||
| //Initializing an array to hold parent impurity calculations for each node | ||
| val parentImpurities = new Array[Double](maxNumNodes) | ||
| //Dummy value for top node (updated during first split calculation) | ||
| //parentImpurities(0) = Double.MinValue | ||
| val nodes = new Array[Node](maxNumNodes) | ||
|
|
||
| logDebug("algo = " + strategy.algo) | ||
|
|
||
| //The main idea here is to perform level-wise training of the decision tree nodes thus | ||
| // reducing the passes over the data from l to log2(l) where l is the total number of nodes. | ||
| // Each data sample is checked for validity w.r.t. each node at a given level -- i.e., | ||
| // the sample is only used for the split calculation at the node if the sample would have | ||
| // still survived the filters of the parent nodes. | ||
| breakable { | ||
| for (level <- 0 until maxDepth) { | ||
|
|
||
|
|
@@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log | |
| level, filters, splits, bins) | ||
|
|
||
| for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { | ||
|
|
||
| //Extract info for nodes at the current level | ||
| extractNodeInfo(nodeSplitStats, level, index, nodes) | ||
| //Extract info for nodes at the next lower level | ||
| extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, | ||
| filters) | ||
| logDebug("final best split = " + nodeSplitStats._1) | ||
|
|
||
| } | ||
| require(scala.math.pow(2, level) == splitsStatsForLevel.length) | ||
|
|
||
| //Check whether all the nodes at the current level are leaves | ||
| val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) | ||
| logDebug("all leaf = " + allLeaf) | ||
| if (allLeaf) break | ||
| if (allLeaf) break //no more tree construction | ||
|
|
||
| } | ||
| } | ||
|
|
||
| //Initialize the top or root node of the tree | ||
| val topNode = nodes(0) | ||
| //Build the full tree using the node info calculated in the level-wise best split calculations | ||
| topNode.build(nodes) | ||
|
|
||
| val decisionTreeModel = { | ||
| return new DecisionTreeModel(topNode, strategy.algo) | ||
| } | ||
| return decisionTreeModel | ||
| //Return a decision tree model | ||
| return new DecisionTreeModel(topNode, strategy.algo) | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove extra blank line. |
||
|
|
||
| /** | ||
| * Extract the decision tree node information for the given tree level and node index | ||
| */ | ||
| private def extractNodeInfo( | ||
| nodeSplitStats: (Split, InformationGainStats), | ||
| level: Int, index: Int, | ||
| nodes: Array[Node]) { | ||
| level: Int, | ||
| index: Int, | ||
| nodes: Array[Node]) | ||
| : Unit = { | ||
|
|
||
| val split = nodeSplitStats._1 | ||
| val stats = nodeSplitStats._2 | ||
|
|
@@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log | |
| nodes(nodeIndex) = node | ||
| } | ||
|
|
||
| /** | ||
| * Extract the decision tree node information for the children of the node | ||
| */ | ||
| private def extractInfoForLowerLevels( | ||
| level: Int, | ||
| index: Int, | ||
| maxDepth: Int, | ||
| nodeSplitStats: (Split, InformationGainStats), | ||
| parentImpurities: Array[Double], | ||
| filters: Array[List[Filter]]) { | ||
| filters: Array[List[Filter]]) | ||
| : Unit = { | ||
|
|
||
| // 0 corresponds to the left child node and 1 corresponds to the right child node. | ||
| for (i <- 0 to 1) { | ||
|
|
||
| //Calculating the index of the node from the node level and the index at the current level | ||
| val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i | ||
|
|
||
| if (level < maxDepth - 1) { | ||
|
|
||
| val impurity = if (i == 0) { | ||
| nodeSplitStats._2.leftImpurity | ||
| } else { | ||
| nodeSplitStats._2.rightImpurity | ||
| } | ||
|
|
||
| logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) | ||
| //noting the parent impurities | ||
| parentImpurities(nodeIndex) = impurity | ||
| //noting the parents filters for the child nodes | ||
| val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) | ||
| filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) | ||
|
|
||
| for (filter <- filters(nodeIndex)) { | ||
| logDebug("Filter = " + filter) | ||
| } | ||
|
|
||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put an extra space after "//".