@@ -813,15 +813,9 @@ object DecisionTree extends Serializable with Logging {
813813 logDebug(" node impurity = " + nodeImpurity)
814814
815815 // For each (feature, split), calculate the gain, and select the best (feature, split).
816- // Initialize with infeasible values.
817- var bestFeatureIndex = Int .MinValue
818- var bestSplitIndex = Int .MinValue
819- var bestGainStats = new InformationGainStats (Double .MinValue , - 1.0 , - 1.0 , - 1.0 , - 1.0 )
820- var featureIndex = 0
821- // TODO: Change loops over splits into iterators.
822- while (featureIndex < metadata.numFeatures) {
816+ Range (0 , metadata.numFeatures).map { featureIndex =>
823817 val numSplits = metadata.numSplits(featureIndex)
824- if (metadata.isContinuous(featureIndex)) {
818+ val (bestSplitIndex, bestGainStats) = if (metadata.isContinuous(featureIndex)) {
825819 // println(s"binsToBestSplit: feature $featureIndex (continuous)")
826820 // Cumulative sum (scanLeft) of bin statistics.
827821 // Afterwards, binAggregates for a bin is the sum of aggregates for
@@ -833,39 +827,26 @@ object DecisionTree extends Serializable with Logging {
833827 splitIndex += 1
834828 }
835829 // Find best split.
836- splitIndex = 0
837- while (splitIndex < numSplits) {
838- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
830+ Range (0 , numSplits).map { case splitIdx =>
831+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
839832 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
840833 rightChildStats.subtract(leftChildStats)
841834 val gainStats =
842835 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
843- if (gainStats.gain > bestGainStats.gain) {
844- bestGainStats = gainStats
845- bestFeatureIndex = featureIndex
846- bestSplitIndex = splitIndex
847- }
848- splitIndex += 1
849- }
836+ (splitIdx, gainStats)
837+ }.maxBy(_._2.gain)
850838 } else if (metadata.isUnordered(featureIndex)) {
851839 // println(s"binsToBestSplit: feature $featureIndex (unordered cat)")
852840 // Unordered categorical feature
853841 val (leftChildOffset, rightChildOffset) =
854842 binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
855- var splitIndex = 0
856- while (splitIndex < numSplits) {
843+ Range (0 , numSplits).map { splitIndex =>
857844 val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
858845 val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
859846 val gainStats =
860847 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
861- // println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
862- if (gainStats.gain > bestGainStats.gain) {
863- bestGainStats = gainStats
864- bestFeatureIndex = featureIndex
865- bestSplitIndex = splitIndex
866- }
867- splitIndex += 1
868- }
848+ (splitIndex, gainStats)
849+ }.maxBy(_._2.gain)
869850 } else {
870851 // println(s"binsToBestSplit: feature $featureIndex (ordered cat)")
871852 // Ordered categorical feature
@@ -880,25 +861,17 @@ object DecisionTree extends Serializable with Logging {
880861 splitIndex += 1
881862 }
882863 // Find best split.
883- splitIndex = 0
884- while (splitIndex < numSplits) {
864+ Range (0 , numSplits).map { splitIndex =>
885865 val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
886866 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
887867 rightChildStats.subtract(leftChildStats)
888868 val gainStats =
889869 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
890- // println(s"\t split $splitIndex: gain: ${bestGainStats.gain}")
891- if (gainStats.gain > bestGainStats.gain) {
892- bestGainStats = gainStats
893- bestFeatureIndex = featureIndex
894- bestSplitIndex = splitIndex
895- }
896- splitIndex += 1
897- }
870+ (splitIndex, gainStats)
871+ }.maxBy(_._2.gain)
898872 }
899- featureIndex += 1
900- }
901- (bestFeatureIndex, bestSplitIndex, bestGainStats)
873+ (featureIndex, bestSplitIndex, bestGainStats)
874+ }.maxBy(_._3.gain)
902875 }
903876
904877 /**
0 commit comments