-
Notifications
You must be signed in to change notification settings - Fork 29k
MLI-1 Decision Trees #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cd53eae
92cedce
8bca1e2
0012a77
03f534c
dad0afc
4798aae
80e8c66
b0eb866
98ec8d5
02c595c
733d6dd
154aa77
b0e3e76
c8f6d60
e23c2e5
53108ed
6df35b9
dbb7ac1
d504eb1
6b7de78
b09dc98
c0e522b
f067d68
5841c28
0dd7659
dd0c0d7
9372779
84f85d6
d3023b3
63e786b
cd2c2b4
eb8fcbe
794ff4d
d1ef4f6
ad1fc21
62c2562
6068356
2116360
632818f
ff363a7
4576b64
24500c5
c487e6a
f963ef5
201702f
62dc723
e1dd86f
f536ae9
7d54b4f
1e8c704
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Manish Amde <[email protected]>
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,9 @@ | |
| */ | ||
| package org.apache.spark.mllib.tree.configuration | ||
|
|
||
| /** | ||
| * Enum to select the algorithm for the decision tree | ||
| */ | ||
| object Algo extends Enumeration { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The various |
||
| type Algo = Value | ||
| val Classification, Regression = Value | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity | |
|
|
||
| import javax.naming.OperationNotSupportedException | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use |
||
|
|
||
| /** | ||
| * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini | ||
| * coefficient]] during binary classification | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the wiki page for http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity |
||
| */ | ||
| object Gini extends Impurity { | ||
|
|
||
| /** | ||
| * gini coefficient calculation | ||
| * @param c0 count of instances with label 0 | ||
| * @param c1 count of instances with label 1 | ||
| * @return gini coefficient value | ||
| */ | ||
| def calculate(c0 : Double, c1 : Double): Double = { | ||
| if (c0 == 0 || c1 == 0) { | ||
| 0 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity | |
| import javax.naming.OperationNotSupportedException | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use |
||
| import org.apache.spark.Logging | ||
|
|
||
| /** | ||
| * Class for calculating variance during regression | ||
| */ | ||
| object Variance extends Impurity with Logging { | ||
| def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") | ||
|
|
||
| /** | ||
| * variance calculation | ||
| * @param count number of instances | ||
| * @param sum sum of labels | ||
| * @param sumSquares summation of squares of the labels | ||
| * @return | ||
| */ | ||
| def calculate(count: Double, sum: Double, sumSquares: Double): Double = { | ||
| val squaredLoss = sumSquares - (sum*sum)/count | ||
| squaredLoss/count | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,12 +16,23 @@ | |
| */ | ||
| package org.apache.spark.mllib.tree.model | ||
|
|
||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.configuration.Algo._ | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| /** | ||
| * Model to store the decision tree parameters | ||
| * @param topNode root node | ||
| * @param algo algorithm type -- classification or regression | ||
| */ | ||
| class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable { | ||
|
|
||
| def predict(features : Array[Double]) = { | ||
| /** | ||
| * Predict values for a single data point using the model trained. | ||
| * | ||
| * @param features array representing a single data point | ||
| * @return Double prediction from the trained model | ||
| */ | ||
| def predict(features : Array[Double]) : Double = { | ||
| algo match { | ||
| case Classification => { | ||
| if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question was asked by @srowen : Is it easy to support multi-class? Also, why is 0.5 used here as the threshold?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 0.5 was just a threshold used for verification. I will make it configurable and for now return a double value between 0 and 1 similar to other classification algorithms in mllib. This will make it easier for performing ROC/AUC calculations. |
||
|
|
@@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Predict values for the given data set using the model trained. | ||
| * | ||
| * @param features RDD representing data points to be predicted | ||
| * @return RDD[Int] where each entry contains the corresponding prediction | ||
| */ | ||
| def predict(features: RDD[Array[Double]]): RDD[Double] = { | ||
| features.map(x => predict(x)) | ||
| } | ||
|
|
||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,14 @@ | |
| */ | ||
| package org.apache.spark.mllib.tree.model | ||
|
|
||
| /** | ||
| * Information gain statistics for each split | ||
| * @param gain information gain value | ||
| * @param impurity current node impurity | ||
| * @param leftImpurity left node impurity | ||
| * @param rightImpurity right node impurity | ||
| * @param predict predicted value | ||
| */ | ||
| class InformationGainStats( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| val gain : Double, | ||
| val impurity: Double, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model | |
|
|
||
| import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType | ||
|
|
||
| /** | ||
| * Split applied to a feature | ||
| * @param feature feature index | ||
| * @param threshold threshold for continuous feature | ||
| * @param featureType type of feature -- categorical or continuous | ||
| * @param categories accepted values for categorical variables | ||
| */ | ||
| case class Split( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The functionality of I think depending on the choice here, we require |
||
| feature: Int, | ||
| threshold : Double, | ||
|
|
@@ -29,12 +36,28 @@ case class Split( | |
| ", categories = " + categories | ||
| } | ||
|
|
||
| class DummyLowSplit(feature: Int, kind : FeatureType) | ||
| extends Split(feature, Double.MinValue, kind, List()) | ||
| /** | ||
| * Split with minimum threshold for continuous features. Helps with the smallest bin creation. | ||
| * @param feature feature index | ||
| * @param featureType type of feature -- categorical or continuous | ||
| */ | ||
| class DummyLowSplit(feature: Int, featureType : FeatureType) | ||
| extends Split(feature, Double.MinValue, featureType, List()) | ||
|
|
||
| class DummyHighSplit(feature: Int, kind : FeatureType) | ||
| extends Split(feature, Double.MaxValue, kind, List()) | ||
| /** | ||
| * Split with maximum threshold for continuous features. Helps with the highest bin creation. | ||
| * @param feature feature index | ||
| * @param featureType type of feature -- categorical or continuous | ||
| */ | ||
| class DummyHighSplit(feature: Int, featureType : FeatureType) | ||
| extends Split(feature, Double.MaxValue, featureType, List()) | ||
|
|
||
| class DummyCategoricalSplit(feature: Int, kind : FeatureType) | ||
| extends Split(feature, Double.MaxValue, kind, List()) | ||
| /** | ||
| * Split with no acceptable feature values for categorical features. Helps with the first bin | ||
| * creation. | ||
| * @param feature feature index | ||
| * @param featureType type of feature -- categorical or continuous | ||
| */ | ||
| class DummyCategoricalSplit(feature: Int, featureType : FeatureType) | ||
| extends Split(feature, Double.MaxValue, featureType, List()) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use JavaDoc style.