-
Notifications
You must be signed in to change notification settings - Fork 29k
MLI-1 Decision Trees #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cd53eae
92cedce
8bca1e2
0012a77
03f534c
dad0afc
4798aae
80e8c66
b0eb866
98ec8d5
02c595c
733d6dd
154aa77
b0e3e76
c8f6d60
e23c2e5
53108ed
6df35b9
dbb7ac1
d504eb1
6b7de78
b09dc98
c0e522b
f067d68
5841c28
0dd7659
dd0c0d7
9372779
84f85d6
d3023b3
63e786b
cd2c2b4
eb8fcbe
794ff4d
d1ef4f6
ad1fc21
62c2562
6068356
2116360
632818f
ff363a7
4576b64
24500c5
c487e6a
f963ef5
201702f
62dc723
e1dd86f
f536ae9
7d54b4f
1e8c704
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Manish Amde <[email protected]>
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.classification | ||
|
|
||
| class ClassificationTree { | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.regression | ||
|
|
||
| class RegressionTree { | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.tree | ||
|
|
||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.model.{Split, Bin, DecisionTreeModel} | ||
|
|
||
|
|
||
| class DecisionTree(val strategy : Strategy) { | ||
|
|
||
| def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no space between "input" and ":" |
||
|
|
||
| //Cache input RDD for speedup during multiple passes | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put an extra space after "//". |
||
| input.cache() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the current implementation of other algorithms in MLlib, we let users to choose whether the data should be cached or not. How many passes does your algorithm need? |
||
|
|
||
| //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass | ||
| val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) | ||
|
|
||
| //TODO: Level-wise training of tree and obtain Decision Tree model | ||
|
|
||
|
|
||
| return new DecisionTreeModel() | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove extra blank line. |
||
| } | ||
|
|
||
| object DecisionTree { | ||
| def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { | ||
| val numSplits = strategy.numSplits | ||
| //TODO: Justify this calculation | ||
| val requiredSamples : Long = numSplits*numSplits | ||
| val count : Long = input.count() | ||
| val numSamples : Long = if (requiredSamples < count) requiredSamples else count | ||
| val numFeatures = input.take(1)(0).features.length | ||
| (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits)) | ||
| } | ||
|
|
||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new line at the end |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| This package contains the default implementation of the decision tree algorithm. | ||
|
|
||
| The decision tree algorithm supports: | ||
| + information loss calculation with entropy and gini for classification and variance for regression | ||
| + node model pruning | ||
| + printing to dot files | ||
| + unit tests | ||
|
|
||
| #Performance testing | ||
|
|
||
| #Future Extensions | ||
|
|
||
| + Random forests | ||
| + Boosting | ||
| + Extremely randomized trees |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree | ||
|
|
||
| import org.apache.spark.mllib.tree.impurity.Impurity | ||
|
|
||
| class Strategy ( | ||
| val kind : String, | ||
| val impurity : Impurity, | ||
| val maxDepth : Int, | ||
| val numSplits : Int, | ||
| val quantileCalculationStrategy : String = "sampleAndSort") { | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.impurity | ||
|
|
||
| object Entropy extends Impurity { | ||
|
|
||
| def log2(x: Double) = scala.math.log(x) / scala.math.log(2) | ||
|
|
||
| def calculate(c0: Double, c1: Double): Double = { | ||
| if (c0 == 0 || c1 == 0) { | ||
| 0 | ||
| } else { | ||
| val total = c0 + c1 | ||
| val f0 = c0 / total | ||
| val f1 = c1 / total | ||
| -(f0 * log2(f0)) - (f1 * log2(f1)) | ||
| } | ||
| } | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.impurity | ||
|
|
||
| object Gini extends Impurity { | ||
|
|
||
| def calculate(c0 : Double, c1 : Double): Double = { | ||
| val total = c0 + c1 | ||
| val f0 = c0 / total | ||
| val f1 = c1 / total | ||
| 1 - f0*f0 - f1*f1 | ||
| } | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.impurity | ||
|
|
||
| trait Impurity { | ||
|
|
||
| def calculate(c0 : Double, c1 : Double): Double | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. JavaDoc for public methods. |
||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.impurity | ||
|
|
||
| import javax.naming.OperationNotSupportedException | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use |
||
|
|
||
| object Variance extends Impurity { | ||
| def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.model | ||
|
|
||
| case class Bin(kind : String, lowSplit : Split, highSplit : Split) { | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.model | ||
|
|
||
| class DecisionTreeModel { | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.mllib.tree.model | ||
|
|
||
| case class Split( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The functionality of I think depending on the choice here, we require |
||
| val feature: Int, | ||
| val threshold : Double, | ||
| val kind : String) { | ||
|
|
||
| } | ||
|
|
||
| class dummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind) | ||
|
|
||
| class dummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A plan should be made to have a consistent hierarchy and organization of various algorithms in the MLLib package. A separate
treesubpackage seems unnecessary.