Closed

Changes from 1 commit · 51 commits
cd53eae
skeletal framework
manishamde Nov 28, 2013
92cedce
basic building blocks for intermediate RDD calculation. untested.
manishamde Dec 2, 2013
8bca1e2
additional code for creating intermediate RDD
manishamde Dec 9, 2013
0012a77
basic stump working
manishamde Dec 10, 2013
03f534c
some more tests
manishamde Dec 10, 2013
dad0afc
decision stump functionality working
manishamde Dec 15, 2013
4798aae
added gain stats class
manishamde Dec 15, 2013
80e8c66
working version of multi-level split calculation
manishamde Dec 16, 2013
b0eb866
added logic to handle leaf nodes
manishamde Dec 16, 2013
98ec8d5
tree building and prediction logic
manishamde Dec 22, 2013
02c595c
added command line parsing
manishamde Dec 22, 2013
733d6dd
fixed tests
manishamde Dec 22, 2013
154aa77
enums for configurations
manishamde Dec 23, 2013
b0e3e76
adding enum for feature type
manishamde Jan 12, 2014
c8f6d60
adding enum for feature type
manishamde Jan 12, 2014
e23c2e5
added regression support
manishamde Jan 19, 2014
53108ed
fixing index for highest bin
manishamde Jan 20, 2014
6df35b9
regression predict logic
manishamde Jan 21, 2014
dbb7ac1
categorical feature support
manishamde Jan 23, 2014
d504eb1
more tests for categorical features
manishamde Jan 23, 2014
6b7de78
minor refactoring and tests
manishamde Jan 26, 2014
b09dc98
minor refactoring
manishamde Jan 26, 2014
c0e522b
updated predict and split threshold logic
manishamde Jan 27, 2014
f067d68
minor cleanup
manishamde Jan 27, 2014
5841c28
unit tests for categorical features
manishamde Jan 27, 2014
0dd7659
basic doc
manishamde Jan 27, 2014
dd0c0d7
minor: some docs
manishamde Jan 27, 2014
9372779
code style: max line length <= 100
manishamde Feb 17, 2014
84f85d6
code documentation
manishamde Feb 28, 2014
d3023b3
adding more docs for nested methods
manishamde Mar 6, 2014
63e786b
added multiple train methods for java compatibility
manishamde Mar 6, 2014
cd2c2b4
fixing code style based on feedback
manishamde Mar 7, 2014
eb8fcbe
minor code style updates
manishamde Mar 7, 2014
794ff4d
minor improvements to docs and style
manishamde Mar 10, 2014
d1ef4f6
more documentation
manishamde Mar 10, 2014
ad1fc21
incorporated mengxr's code style suggestions
manishamde Mar 11, 2014
62c2562
fixing comment indentation
manishamde Mar 11, 2014
6068356
ensuring num bins is always greater than max number of categories
manishamde Mar 12, 2014
2116360
removing dummy bin calculation for categorical variables
manishamde Mar 12, 2014
632818f
removing threshold for classification predict method
manishamde Mar 13, 2014
ff363a7
binary search for bins and while loop for categorical feature bins
manishamde Mar 17, 2014
4576b64
documentation and for to while loop conversion
manishamde Mar 23, 2014
24500c5
minor style updates
mengxr Mar 23, 2014
c487e6a
Merge pull request #1 from mengxr/dtree
manishamde Mar 23, 2014
f963ef5
making methods private
manishamde Mar 23, 2014
201702f
making some more methods private
manishamde Mar 23, 2014
62dc723
updating javadoc and converting helper methods to package private to …
manishamde Mar 24, 2014
e1dd86f
implementing code style suggestions
manishamde Mar 25, 2014
f536ae9
another pass on code style
mengxr Mar 31, 2014
7d54b4f
Merge pull request #4 from mengxr/dtree
manishamde Mar 31, 2014
1e8c704
remove numBins field in the Strategy class
manishamde Apr 1, 2014
code documentation
Signed-off-by: Manish Amde <[email protected]>
manishamde committed Feb 28, 2014
commit 84f85d6d0a1fe7ed60149cc6b29a9ff76ef09abd
31 changes: 18 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Algo._

/*
/**
A class that implements a decision tree algorithm for classification and regression.
[Contributor comment] Use JavaDoc style.

It supports both continuous and categorical features.

@@ -40,7 +40,7 @@ quantile calculation strategy, etc.
*/
class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

/*
/**
Method to train a decision tree model over an RDD

@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -157,14 +157,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

object DecisionTree extends Serializable with Logging {

/*
/**
Returns an Array[Split] of optimal splits for all nodes at a given level

@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
             for DecisionTree
@param parentImpurities Impurities for all parent nodes for the current level
@param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
                parameters for constructing the DecisionTree
@param level Level of the tree
@param filters Filter for all nodes at a given level
@param splits possible splits for all features
@@ -200,7 +200,7 @@ object DecisionTree extends Serializable with Logging {
}
}

/*
/**
Find whether the sample is valid input for the current node.
In other words, does it pass through all the filters for the current node.
*/
@@ -236,7 +236,9 @@ object DecisionTree extends Serializable with Logging {
true
}

/*Finds the right bin for the given feature*/
/**
Finds the right bin for the given feature
*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {

if (isFeatureContinuous){
@@ -266,7 +268,8 @@

}
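For the continuous case above, the bin lookup amounts to locating the one bin whose two splits bracket the feature value. The following is a minimal sketch of that idea, assuming a linear scan over an ordered `Array[Bin]` for the feature (the Bin case class appears in model/Bin.scala below); the half-open (low, high] convention, the helper name, and the parameter list are assumptions, not this commit's exact code (a later commit in this PR switches to binary search).

```scala
// Hedged sketch: scan the ordered bins for one feature until a bin's
// (lowSplit, highSplit] range brackets the value.
def findBinSketch(featureValue: Double, featureBins: Array[Bin]): Int = {
  var binIndex = 0
  while (binIndex < featureBins.length) {
    val bin = featureBins(binIndex)
    // DummyLowSplit/DummyHighSplit carry Double.MinValue/MaxValue thresholds,
    // so the first and last bins are effectively unbounded on one side.
    if (featureValue > bin.lowSplit.threshold &&
        featureValue <= bin.highSplit.threshold) {
      return binIndex
    }
    binIndex += 1
  }
  -1 // value matched no bin
}
```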

/*Finds bins for all nodes (and all features) at a given level
/**
Finds bins for all nodes (and all features) at a given level
k features, l nodes (level = log2(l))
Storage: label, b11, b12, ..., b1k, b21, b22, ..., bl1, bl2, ..., blk
An invalid sample for the tree is denoted by setting the bin for feature 1 to -1
@@ -343,7 +346,8 @@
}
}
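The flattened per-point layout described in the doc above (label first, then one bin index per node and feature, node-major) admits simple offset arithmetic. A sketch of that interpretation; the helper and parameter names are illustrative, not from the PR:

```scala
// Per data point: arr(0) = label, followed by one bin index per
// (node, feature) pair in node-major order; indices are zero-based.
def binOffset(nodeIndex: Int, featureIndex: Int, numFeatures: Int): Int =
  1 + nodeIndex * numFeatures + featureIndex

// Marking a sample invalid for a node, per the convention above:
//   arr(binOffset(nodeIndex, 0, numFeatures)) = -1
```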

/*Performs a sequential aggregation over a partition.
/**
Performs a sequential aggregation over a partition.

for p bins, k features, l nodes (level = log2(l)) storage is of the form:
b111_left_count,b111_right_count, .... , ..
@@ -370,7 +374,8 @@
}
logDebug("binAggregateLength = " + binAggregateLength)

/*Combines the aggregates from partitions
/**
Combines the aggregates from partitions
@param agg1 Array containing aggregates from one or more partitions
@param agg2 Array containing aggregates from one or more partitions

@@ -507,7 +512,7 @@ object DecisionTree extends Serializable with Logging {
}
}
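Because both partition aggregates share the same flat bin layout, combining reduces to an element-wise sum. A minimal sketch under that assumption; the actual combine function in the PR may reuse agg1 in place rather than allocate:

```scala
// Hedged sketch: element-wise sum of two identically laid-out aggregates.
def combineAggregatesSketch(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
  require(agg1.length == agg2.length, "aggregates must have identical layout")
  val combined = new Array[Double](agg1.length)
  var index = 0
  while (index < agg1.length) { // while loop, matching the style used in this PR
    combined(index) = agg1(index) + agg2(index)
    index += 1
  }
  combined
}
```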

/*
/**
Extracts left and right split aggregates

@param binData Array[Double] of size 2*numFeatures*numSplits
@@ -604,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
gains
}

/*
/**
Find the best split for a node given bin aggregate data

@param binData Array[Double] of size 2*numSplits*numFeatures
@@ -669,7 +674,7 @@ object DecisionTree extends Serializable with Logging {
bestSplits
}

/*
/**
Returns split and bins for decision tree calculation.

@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
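Putting this file's pieces together, a hedged end-to-end sketch of the API as shown in this diff. The local-mode SparkContext, the Strategy defaults for the unlisted parameters, and the expected output are assumptions at this commit:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeExample extends App {
  val sc = new SparkContext("local", "DecisionTreeExample")

  // Toy training data: a label followed by two continuous features.
  val trainingData = sc.parallelize(Seq(
    LabeledPoint(0.0, Array(1.0, 0.0)),
    LabeledPoint(0.0, Array(0.9, 0.1)),
    LabeledPoint(1.0, Array(0.0, 1.0)),
    LabeledPoint(1.0, Array(0.1, 0.9))))

  // Assumes maxBins, quantile strategy, and categoricalFeaturesInfo have
  // defaults at this commit; otherwise pass them explicitly.
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 3)
  val model = new DecisionTree(strategy).train(trainingData)

  println(model.predict(Array(0.05, 0.95))) // expected 1.0 on this toy data
  sc.stop()
}
```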
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

/**
* Enum to select the algorithm for the decision tree
*/
object Algo extends Enumeration {

[Contributor comment] The Algorithm Enumeration seems redundant given Impurity, which implies the Algorithm anyway.

[Contributor comment] The various Enumeration classes in the mllib.tree.configuration package are neat. A uniform design pattern for parameters and options should be used for MLlib and Spark, and this could be a start. Alternatively, if there is an existing pattern in use, it should be followed for decision tree as well.

type Algo = Value
val Classification, Regression = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

/**
* Enum to describe whether a feature is "continuous" or "categorical"
*/
object FeatureType extends Enumeration {
type FeatureType = Value
val Continuous, Categorical = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

/**
* Enum for selecting the quantile calculation strategy
*/
object QuantileStrategy extends Enumeration {
type QuantileStrategy = Value
val Sort, MinMax, ApproxHist = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -20,6 +20,19 @@ import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._

/**
* Stores all the configuration options for tree construction
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
class Strategy (
val algo : Algo,
val impurity : Impurity,
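The categoricalFeaturesInfo map is the subtlest parameter, so here is a hedged construction sketch (imports as in the earlier end-to-end example; assumes the trailing constructor parameters match the doc above, with defaults for anything unlisted):

```scala
// Hypothetical configuration: features 0 and 1 continuous, feature 2
// categorical with 4 categories taking values 0.0, 1.0, 2.0, 3.0.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  maxBins = 100,
  categoricalFeaturesInfo = Map(2 -> 4))
```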
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -18,10 +18,20 @@ package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

/**
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
*/
object Entropy extends Impurity {

def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

/**
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
*/
def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
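The body is truncated above; a standalone sketch of what it presumably computes — binary entropy -f0·log2(f0) - f1·log2(f1) with f_i = c_i / (c0 + c1), and 0 when either class is absent (the sketch name is illustrative):

```scala
def log2(x: Double): Double = math.log(x) / math.log(2)

def entropySketch(c0: Double, c1: Double): Double = {
  if (c0 == 0 || c1 == 0) {
    0.0
  } else {
    val total = c0 + c1
    val f0 = c0 / total
    val f1 = c1 / total
    -f0 * log2(f0) - f1 * log2(f1)
  }
}
// entropySketch(5, 5) == 1.0, the maximum for a 50/50 split
```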
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException
[Contributor comment] Should use java.lang.UnsupportedOperationException.


/**
* Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini
* coefficient]] during binary classification
[Contributor comment] This is the wiki page for Gini coefficient, which is different from Gini impurity. Should change to
http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity

*/
object Gini extends Impurity {

/**
* gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return gini coefficient value
*/
def calculate(c0 : Double, c1 : Double): Double = {
if (c0 == 0 || c1 == 0) {
0
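Again the body is truncated; a standalone sketch of two-class Gini impurity, 1 - f0² - f1² with f_i = c_i / (c0 + c1), which the method presumably computes (sketch name illustrative):

```scala
def giniSketch(c0: Double, c1: Double): Double = {
  if (c0 == 0 || c1 == 0) {
    0.0
  } else {
    val total = c0 + c1
    val f0 = c0 / total
    val f1 = c1 / total
    1.0 - f0 * f0 - f1 * f1
  }
}
// giniSketch(5, 5) == 0.5, the maximum for two classes
```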
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity
import javax.naming.OperationNotSupportedException
[Contributor comment] Should use java.lang.UnsupportedOperationException and organize imports.

import org.apache.spark.Logging

/**
* Class for calculating variance during regression
*/
object Variance extends Impurity with Logging {
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")

/**
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return variance of the labels
*/
def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum*sum)/count
squaredLoss/count
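For the record, the two-line body above is the standard shortcut form of the variance:

\[
\frac{\texttt{sumSquares} - \texttt{sum}^2 / \texttt{count}}{\texttt{count}}
  \;=\; \frac{1}{n}\sum_{i=1}^{n} x_i^2 \;-\; \bar{x}^2
  \;=\; \frac{1}{n}\sum_{i=1}^{n} \left(x_i - \bar{x}\right)^2,
\qquad n = \texttt{count},\quad \bar{x} = \frac{\texttt{sum}}{n}.
\]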
11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -18,6 +18,17 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
* Used for "binning" the feature values for faster best split calculation. For a continuous
* feature, a bin is determined by a low and a high "split". For a categorical feature,
* a bin is determined by a single label value (category).
* @param lowSplit split signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit split signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {

}
mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -16,12 +16,23 @@
*/
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD

/**
* Model to store the decision tree parameters
* @param topNode root node
* @param algo algorithm type -- classification or regression
*/
class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {

def predict(features : Array[Double]) = {
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features : Array[Double]) : Double = {
algo match {
case Classification => {
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
[Contributor comment] Same question was asked by @srowen: Is it easy to support multi-class?

Also, why is 0.5 used here as the threshold?

[Author reply] 0.5 was just a threshold used for verification. I will make it configurable and, for now, return a double value between 0 and 1, similar to other classification algorithms in MLlib. This will make it easier to perform ROC/AUC calculations.

@@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
}
}

/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
def predict(features: RDD[Array[Double]]): RDD[Double] = {
features.map(x => predict(x))
}


}
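A hedged usage sketch for the two predict overloads above; `sc` is assumed to be an existing SparkContext and `model` a trained DecisionTreeModel, as in the earlier end-to-end example:

```scala
// Single point: Double prediction.
val one: Double = model.predict(Array(0.5, 1.0))

// Batch: RDD[Array[Double]] in, RDD[Double] out.
val points = sc.parallelize(Seq(Array(0.5, 1.0), Array(1.5, 0.0)))
val predictions: org.apache.spark.rdd.RDD[Double] = model.predict(points)
predictions.collect().foreach(println)
```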
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -16,6 +16,11 @@
*/
package org.apache.spark.mllib.tree.model

/**
* Filter specifying a split and type of comparison to be applied on features
* @param split split specifying the feature index, type and threshold
* @param comparison integer specifying <,=,>
*/
case class Filter(split : Split, comparison : Int) {
// Comparison -1,0,1 signifies <,=,>
override def toString = " split = " + split + ", comparison = " + comparison
Expand Down
mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -16,6 +16,14 @@
*/
package org.apache.spark.mllib.tree.model

/**
* Information gain statistics for each split
* @param gain information gain value
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
*/
class InformationGainStats(

[Contributor comment] InformationGainStats and Split nicely separate the members of Node, but can also be flattened and put at top level. Would make storage and explanation slightly easier, albeit less structured.

val gain : Double,
val impurity: Double,
10 changes: 10 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -20,6 +20,16 @@ import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
* Node in a decision tree
* @param id integer node id
* @param predict predicted value at the node
* @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
* @param stats information gain stats
*/
class Node ( val id : Int,
val predict : Double,
val isLeaf : Boolean,
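A hypothetical sketch of how prediction could descend a Node, based on the fields documented above. It assumes split/leftNode/rightNode are Option-valued and handles only the continuous-feature case; the real predictIfLeaf used by DecisionTreeModel.predict may differ:

```scala
def predictSketch(node: Node, features: Array[Double]): Double = {
  if (node.isLeaf) {
    node.predict
  } else {
    val split = node.split.get
    if (features(split.feature) <= split.threshold) {
      predictSketch(node.leftNode.get, features) // value at or below threshold: go left
    } else {
      predictSketch(node.rightNode.get, features) // value above threshold: go right
    }
  }
}
```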
35 changes: 29 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

/**
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
*/
case class Split(

[Contributor comment] The functionality of Split can be simplified by a modification. If I understand correctly, Split represents the left or right (low or high) branch of the parent node. Instead, it suffices to store the branching condition for each node as a splitting condition. This could be appropriately named SplitPredicate, SplittingCondition, or BranchingCondition, and consist of feature id, feature type (continuous or categorical), threshold, left branch categories, and right branch categories.

Depending on the choice here, we may or may not require Filter, but nonetheless I think it's redundant, and we should exploit the recursive/linked structure of the tree, which we are doing anyway.

feature: Int,
threshold : Double,
@@ -29,12 +36,28 @@ case class Split(
", categories = " + categories
}

class DummyLowSplit(feature: Int, kind : FeatureType)
extends Split(feature, Double.MinValue, kind, List())
/**
* Split with minimum threshold for continuous features. Helps with the smallest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyLowSplit(feature: Int, featureType : FeatureType)
extends Split(feature, Double.MinValue, featureType, List())

class DummyHighSplit(feature: Int, kind : FeatureType)
extends Split(feature, Double.MaxValue, kind, List())
/**
* Split with maximum threshold for continuous features. Helps with the highest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyHighSplit(feature: Int, featureType : FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

class DummyCategoricalSplit(feature: Int, kind : FeatureType)
extends Split(feature, Double.MaxValue, kind, List())
/**
* Split with no acceptable feature values for categorical features. Helps with the first bin
* creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
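To tie Split, the dummy splits, and Bin together, a hedged construction sketch. The interior split values are hypothetical, and the Double.MinValue category argument is assumed to be an unused placeholder for continuous bins:

```scala
import org.apache.spark.mllib.tree.configuration.FeatureType.Continuous
import org.apache.spark.mllib.tree.model.{Bin, DummyHighSplit, DummyLowSplit, Split}

// Suppose quantile calculation produced two interior splits for continuous
// feature 0. The dummy splits bound the outermost bins so that every
// feature value lands in some bin.
val featureIndex = 0
val interior = Array(
  Split(featureIndex, 10.0, Continuous, List()),
  Split(featureIndex, 20.0, Continuous, List()))

val bins = Array(
  Bin(new DummyLowSplit(featureIndex, Continuous), interior(0), Continuous, Double.MinValue),
  Bin(interior(0), interior(1), Continuous, Double.MinValue),
  Bin(interior(1), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue))
// A value of 12.5 falls in bins(1): 10.0 < 12.5 <= 20.0.
```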