1717
1818package org .apache .spark .mllib .tree
1919
20- import scala .util .control .Breaks ._
2120import org .apache .spark .SparkContext ._
2221import org .apache .spark .rdd .RDD
2322import org .apache .spark .mllib .tree .model ._
@@ -29,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2928import org .apache .spark .mllib .tree .configuration .FeatureType ._
3029import org .apache .spark .mllib .tree .configuration .Algo ._
3130import org .apache .spark .mllib .tree .impurity .{Variance , Entropy , Gini , Impurity }
31+ import scala .util .control .Breaks ._
3232
3333/**
3434 * A class that implements a decision tree algorithm for classification and regression. It
@@ -181,8 +181,8 @@ object DecisionTree extends Serializable with Logging {
181181 input : RDD [LabeledPoint ],
182182 algo : Algo ,
183183 impurity : Impurity ,
184- maxDepth : Int
185- ) : DecisionTreeModel = {
184+ maxDepth : Int )
185+ : DecisionTreeModel = {
186186 val strategy = new Strategy (algo,impurity,maxDepth)
187187 new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
188188 }
@@ -211,8 +211,8 @@ object DecisionTree extends Serializable with Logging {
211211 maxDepth : Int ,
212212 maxBins : Int ,
213213 quantileCalculationStrategy : QuantileStrategy ,
214- categoricalFeaturesInfo : Map [Int ,Int ]
215- ) : DecisionTreeModel = {
214+ categoricalFeaturesInfo : Map [Int ,Int ])
215+ : DecisionTreeModel = {
216216 val strategy = new Strategy (algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
217217 categoricalFeaturesInfo)
218218 new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
@@ -238,7 +238,8 @@ object DecisionTree extends Serializable with Logging {
238238 level : Int ,
239239 filters : Array [List [Filter ]],
240240 splits : Array [Array [Split ]],
241- bins : Array [Array [Bin ]]): Array [(Split , InformationGainStats )] = {
241+ bins : Array [Array [Bin ]])
242+ : Array [(Split , InformationGainStats )] = {
242243
243244 // Common calculations for multiple nested methods
244245 val numNodes = scala.math.pow(2 , level).toInt
@@ -301,7 +302,8 @@ object DecisionTree extends Serializable with Logging {
301302 def findBin (
302303 featureIndex : Int ,
303304 labeledPoint : LabeledPoint ,
304- isFeatureContinuous : Boolean ): Int = {
305+ isFeatureContinuous : Boolean )
306+ : Int = {
305307
306308 if (isFeatureContinuous){
307309 for (binIndex <- 0 until strategy.numBins) {
@@ -515,7 +517,8 @@ object DecisionTree extends Serializable with Logging {
515517 featureIndex : Int ,
516518 splitIndex : Int ,
517519 rightNodeAgg : Array [Array [Double ]],
518- topImpurity : Double ): InformationGainStats = {
520+ topImpurity : Double )
521+ : InformationGainStats = {
519522
520523 strategy.algo match {
521524 case Classification => {
@@ -694,7 +697,8 @@ object DecisionTree extends Serializable with Logging {
694697 def calculateGainsForAllNodeSplits (
695698 leftNodeAgg : Array [Array [Double ]],
696699 rightNodeAgg : Array [Array [Double ]],
697- nodeImpurity : Double ): Array [Array [InformationGainStats ]] = {
700+ nodeImpurity : Double )
701+ : Array [Array [InformationGainStats ]] = {
698702
699703 val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
700704
@@ -715,7 +719,8 @@ object DecisionTree extends Serializable with Logging {
715719 */
716720 def binsToBestSplit (
717721 binData : Array [Double ],
718- nodeImpurity : Double ): (Split , InformationGainStats ) = {
722+ nodeImpurity : Double )
723+ : (Split , InformationGainStats ) = {
719724
720725 logDebug(" node impurity = " + nodeImpurity)
721726 val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -786,7 +791,8 @@ object DecisionTree extends Serializable with Logging {
786791 */
787792 def findSplitsBins (
788793 input : RDD [LabeledPoint ],
789- strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
794+ strategy : Strategy )
795+ : (Array [Array [Split ]], Array [Array [Bin ]]) = {
790796
791797 val count = input.count()
792798
@@ -947,12 +953,11 @@ object DecisionTree extends Serializable with Logging {
947953 }
948954 val options = nextOption(Map (),arglist)
949955 logDebug(options.toString())
950- // TODO: Add validation for input parameters
951956
952957 // Load training data
953958 val trainData = loadLabeledData(sc, options.get(' trainDataDir ).get.toString)
954959
955- // Figure out the type of algorithm
960+ // Identify the type of algorithm
956961 val algoStr = options.get(' algo ).get.toString
957962 val algo = algoStr match {
958963 case " Classification" => Classification
@@ -1007,7 +1012,10 @@ object DecisionTree extends Serializable with Logging {
10071012 }
10081013 }
10091014
1010- // TODO: Port them to a metrics package
1015+ // TODO: Port this method to a generic metrics package
1016+ /**
1017+ * Calculates the classifier accuracy.
1018+ */
10111019 def accuracyScore (model : DecisionTreeModel , data : RDD [LabeledPoint ]): Double = {
10121020 val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
10131021 val count = data.count()
@@ -1016,7 +1024,10 @@ object DecisionTree extends Serializable with Logging {
10161024 correctCount.toDouble / count
10171025 }
10181026
1019- // TODO: Make these generic MLTable metrics
1027+ // TODO: Port this method to a generic metrics package
1028+ /**
1029+ * Calculates the mean squared error for regression.
1030+ */
10201031 def meanSquaredError (tree : DecisionTreeModel , data : RDD [LabeledPoint ]): Double = {
10211032 val meanSumOfSquares =
10221033 data.map(y => (tree.predict(y.features) - y.label)* (tree.predict(y.features) - y.label))
0 commit comments