-
Notifications
You must be signed in to change notification settings - Fork 29.1k
MLI-1 Decision Trees #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cd53eae
92cedce
8bca1e2
0012a77
03f534c
dad0afc
4798aae
80e8c66
b0eb866
98ec8d5
02c595c
733d6dd
154aa77
b0e3e76
c8f6d60
e23c2e5
53108ed
6df35b9
dbb7ac1
d504eb1
6b7de78
b09dc98
c0e522b
f067d68
5841c28
0dd7659
dd0c0d7
9372779
84f85d6
d3023b3
63e786b
cd2c2b4
eb8fcbe
794ff4d
d1ef4f6
ad1fc21
62c2562
6068356
2116360
632818f
ff363a7
4576b64
24500c5
c487e6a
f963ef5
201702f
62dc723
e1dd86f
f536ae9
7d54b4f
1e8c704
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Manish Amde <manish9ue@gmail.com>
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,8 +23,9 @@ import org.apache.spark.mllib.tree.model._ | |
| import org.apache.spark.{SparkContext, Logging} | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.model.Split | ||
| import org.apache.spark.mllib.tree.impurity.Gini | ||
| import scala.util.control.Breaks._ | ||
| import org.apache.spark.mllib.tree.configuration.Strategy | ||
| import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ | ||
|
|
||
|
|
||
| class DecisionTree(val strategy : Strategy) extends Serializable with Logging { | ||
|
|
@@ -34,8 +35,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { | |
| //Cache input RDD for speedup during multiple passes | ||
| input.cache() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the current implementation of other algorithms in MLlib, we let users choose whether the data should be cached or not. How many passes does your algorithm need? |
||
|
|
||
| //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass | ||
| //TODO: Think about broadcasting this | ||
| val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) | ||
| logDebug("numSplits = " + bins(0).length) | ||
| strategy.numBins = bins(0).length | ||
|
|
@@ -133,7 +132,7 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree | ||
| @param parentImpurities Impurities for all parent nodes for the current level | ||
| @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @param level Level of the tree | ||
| @param filters Filter for all nodes at a given level | ||
| @param splits possible splits for all features | ||
|
|
@@ -406,27 +405,18 @@ object DecisionTree extends Serializable with Logging { | |
| val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) | ||
| val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) | ||
|
|
||
| //logDebug("gains.size = " + gains.size) | ||
| //logDebug("gains(0).size = " + gains(0).size) | ||
|
|
||
| val (bestFeatureIndex,bestSplitIndex, gainStats) = { | ||
| var bestFeatureIndex = 0 | ||
| var bestSplitIndex = 0 | ||
| //Initialization with infeasible values | ||
| var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0) | ||
| // var maxGain = Double.MinValue | ||
| // var leftSamples = Long.MinValue | ||
| // var rightSamples = Long.MinValue | ||
| for (featureIndex <- 0 until numFeatures) { | ||
| for (splitIndex <- 0 until numSplits - 1){ | ||
| val gainStats = gains(featureIndex)(splitIndex) | ||
| //logDebug("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) | ||
| if(gainStats.gain > bestGainStats.gain) { | ||
| bestGainStats = gainStats | ||
| bestFeatureIndex = featureIndex | ||
| bestSplitIndex = splitIndex | ||
| //logDebug("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) | ||
| //logDebug( "gain stats = " + bestGainStats) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -455,7 +445,7 @@ object DecisionTree extends Serializable with Logging { | |
| Returns split and bins for decision tree calculation. | ||
|
|
||
| @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree | ||
| @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters for construction the DecisionTree | ||
| @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an | ||
| Array[Array[Bin]] of size (numFeatures,numSplits1) | ||
| */ | ||
|
|
@@ -483,7 +473,7 @@ object DecisionTree extends Serializable with Logging { | |
| logDebug("stride = " + stride) | ||
|
|
||
| strategy.quantileCalculationStrategy match { | ||
| case "sort" => { | ||
| case Sort => { | ||
| val splits = Array.ofDim[Split](numFeatures,numBins-1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove the extra space after "=".
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Put an extra space after ",". |
||
| val bins = Array.ofDim[Bin](numFeatures,numBins) | ||
|
|
||
|
|
@@ -514,10 +504,10 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| (splits,bins) | ||
| } | ||
| case "minMax" => { | ||
| case MinMax => { | ||
| (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2)) | ||
| } | ||
| case "approximateHistogram" => { | ||
| case ApproxHist => { | ||
| throw new UnsupportedOperationException("approximate histogram not supported yet.") | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.configuration | ||
|
|
||
| object Algo extends Enumeration { | ||
|
||
| type Algo = Value | ||
| val Classification, Regression = Value | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.configuration | ||
|
|
||
| object QuantileStrategy extends Enumeration { | ||
| type QuantileStrategy = Value | ||
| val Sort, MinMax, ApproxHist = Value | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put an extra space after "//".