Skip to content

Commit 794ff4d

Browse files
committed
minor improvements to docs and style
1 parent eb8fcbe commit 794ff4d

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import scala.util.control.Breaks._
2120
import org.apache.spark.SparkContext._
2221
import org.apache.spark.rdd.RDD
2322
import org.apache.spark.mllib.tree.model._
@@ -29,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2928
import org.apache.spark.mllib.tree.configuration.FeatureType._
3029
import org.apache.spark.mllib.tree.configuration.Algo._
3130
import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
31+
import scala.util.control.Breaks._
3232

3333
/**
3434
* A class that implements a decision tree algorithm for classification and regression. It
@@ -181,8 +181,8 @@ object DecisionTree extends Serializable with Logging {
181181
input: RDD[LabeledPoint],
182182
algo: Algo,
183183
impurity: Impurity,
184-
maxDepth: Int
185-
): DecisionTreeModel = {
184+
maxDepth: Int)
185+
: DecisionTreeModel = {
186186
val strategy = new Strategy(algo,impurity,maxDepth)
187187
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
188188
}
@@ -211,8 +211,8 @@ object DecisionTree extends Serializable with Logging {
211211
maxDepth: Int,
212212
maxBins: Int,
213213
quantileCalculationStrategy: QuantileStrategy,
214-
categoricalFeaturesInfo: Map[Int,Int]
215-
): DecisionTreeModel = {
214+
categoricalFeaturesInfo: Map[Int,Int])
215+
: DecisionTreeModel = {
216216
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
217217
categoricalFeaturesInfo)
218218
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
@@ -238,7 +238,8 @@ object DecisionTree extends Serializable with Logging {
238238
level: Int,
239239
filters: Array[List[Filter]],
240240
splits: Array[Array[Split]],
241-
bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
241+
bins: Array[Array[Bin]])
242+
: Array[(Split, InformationGainStats)] = {
242243

243244
//Common calculations for multiple nested methods
244245
val numNodes = scala.math.pow(2, level).toInt
@@ -301,7 +302,8 @@ object DecisionTree extends Serializable with Logging {
301302
def findBin(
302303
featureIndex: Int,
303304
labeledPoint: LabeledPoint,
304-
isFeatureContinuous: Boolean): Int = {
305+
isFeatureContinuous: Boolean)
306+
: Int = {
305307

306308
if (isFeatureContinuous){
307309
for (binIndex <- 0 until strategy.numBins) {
@@ -515,7 +517,8 @@ object DecisionTree extends Serializable with Logging {
515517
featureIndex: Int,
516518
splitIndex: Int,
517519
rightNodeAgg: Array[Array[Double]],
518-
topImpurity: Double): InformationGainStats = {
520+
topImpurity: Double)
521+
: InformationGainStats = {
519522

520523
strategy.algo match {
521524
case Classification => {
@@ -694,7 +697,8 @@ object DecisionTree extends Serializable with Logging {
694697
def calculateGainsForAllNodeSplits(
695698
leftNodeAgg: Array[Array[Double]],
696699
rightNodeAgg: Array[Array[Double]],
697-
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
700+
nodeImpurity: Double)
701+
: Array[Array[InformationGainStats]] = {
698702

699703
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
700704

@@ -715,7 +719,8 @@ object DecisionTree extends Serializable with Logging {
715719
*/
716720
def binsToBestSplit(
717721
binData: Array[Double],
718-
nodeImpurity: Double): (Split, InformationGainStats) = {
722+
nodeImpurity: Double)
723+
: (Split, InformationGainStats) = {
719724

720725
logDebug("node impurity = " + nodeImpurity)
721726
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -786,7 +791,8 @@ object DecisionTree extends Serializable with Logging {
786791
*/
787792
def findSplitsBins(
788793
input: RDD[LabeledPoint],
789-
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
794+
strategy: Strategy)
795+
: (Array[Array[Split]], Array[Array[Bin]]) = {
790796

791797
val count = input.count()
792798

@@ -947,12 +953,11 @@ object DecisionTree extends Serializable with Logging {
947953
}
948954
val options = nextOption(Map(),arglist)
949955
logDebug(options.toString())
950-
//TODO: Add validation for input parameters
951956

952957
//Load training data
953958
val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
954959

955-
//Figure out the type of algorithm
960+
//Identify the type of algorithm
956961
val algoStr = options.get('algo).get.toString
957962
val algo = algoStr match {
958963
case "Classification" => Classification
@@ -1007,7 +1012,10 @@ object DecisionTree extends Serializable with Logging {
10071012
}
10081013
}
10091014

1010-
//TODO: Port them to a metrics package
1015+
//TODO: Port this method to a generic metrics package
1016+
/**
1017+
* Calculates the classifier accuracy.
1018+
*/
10111019
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
10121020
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
10131021
val count = data.count()
@@ -1016,7 +1024,10 @@ object DecisionTree extends Serializable with Logging {
10161024
correctCount.toDouble / count
10171025
}
10181026

1019-
//TODO: Make these generic MLTable metrics
1027+
//TODO: Port this method to a generic metrics package
1028+
/**
1029+
* Calculates the mean squared error for regression
1030+
*/
10201031
def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
10211032
val meanSumOfSquares =
10221033
data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label))

0 commit comments

Comments
 (0)