Skip to content
Closed
Prev Previous commit
Next Next commit
Get rid of calculateImpurityStats
  • Loading branch information
MechCoder committed Jun 28, 2016
commit ca8b36088b74cacb7f162fb793070c4d3c6a1a8c
Original file line number Diff line number Diff line change
Expand Up @@ -613,65 +613,6 @@ private[spark] object RandomForest extends Logging {
}
}

/**
 * Build the impurity statistics of one candidate (feature, split) from the aggregates of its
 * left and right children.
 *
 * @param stats previously computed statistics to reuse, or null. When non-null, only its
 *              'impurityCalculator' and 'impurity' fields are read and they are taken as the
 *              parent's aggregates and impurity; when null, both are derived from the sum of
 *              the left and right aggregates.
 * @param leftImpurityCalculator left node aggregates for this (feature, split)
 * @param rightImpurityCalculator right node aggregates for this (feature, split)
 * @param metadata learning and dataset metadata for DecisionTree; supplies
 *                 'minInstancesPerNode' and 'minInfoGain'
 * @return Impurity statistics for this (feature, split), or the result of
 *         ImpurityStats.getInvalidImpurityStats when either child's count is below
 *         'minInstancesPerNode' or the gain is below 'minInfoGain'
 */
private def calculateImpurityStats(
    stats: ImpurityStats,
    leftImpurityCalculator: ImpurityCalculator,
    rightImpurityCalculator: ImpurityCalculator,
    metadata: DecisionTreeMetadata): ImpurityStats = {

  // Parent aggregates and impurity are reused from `stats` when it is supplied,
  // otherwise rebuilt from the two children.
  val reuseCached = stats != null
  val parentCalc: ImpurityCalculator =
    if (reuseCached) {
      stats.impurityCalculator
    } else {
      leftImpurityCalculator.copy.add(rightImpurityCalculator)
    }
  val nodeImpurity: Double =
    if (reuseCached) stats.impurity else parentCalc.calculate()

  val nLeft = leftImpurityCalculator.count
  val nRight = rightImpurityCalculator.count
  val minPerChild = metadata.minInstancesPerNode

  if (nLeft < minPerChild || nRight < minPerChild) {
    // A child falls short of the minimum instances per node: the split is invalid.
    ImpurityStats.getInvalidImpurityStats(parentCalc)
  } else {
    val nTotal = (nLeft + nRight).toDouble
    val leftImpurity = leftImpurityCalculator.calculate()
    val rightImpurity = rightImpurityCalculator.calculate()

    // Gain = parent impurity minus the count-weighted impurities of the two children.
    val gain =
      nodeImpurity - (nLeft / nTotal) * leftImpurity - (nRight / nTotal) * rightImpurity

    if (gain < metadata.minInfoGain) {
      // The split does not reach the minimum information gain: the split is invalid.
      ImpurityStats.getInvalidImpurityStats(parentCalc)
    } else {
      new ImpurityStats(gain, nodeImpurity, parentCalc,
        leftImpurityCalculator, rightImpurityCalculator)
    }
  }
}

/**
* Find the best split for a node.
*
Expand All @@ -684,13 +625,7 @@ private[spark] object RandomForest extends Logging {
featuresForNode: Option[Array[Int]],
node: LearningNode): (Split, ImpurityStats) = {

// Calculate InformationGain and ImpurityStats if current node is top node
val level = LearningNode.indexToLevel(node.id)
var gainAndImpurityStats: ImpurityStats = if (level == 0) {
null
} else {
node.stats
}

// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestGain, bestFeatureOffset, bestSplitIndex) =
Expand Down Expand Up @@ -802,12 +737,21 @@ private[spark] object RandomForest extends Logging {
}
}.maxBy(_._2)

val leftChildStats = binAggregates.getImpurityCalculator(
val leftImpurityCalculator = binAggregates.getImpurityCalculator(
bestFeatureOffset, bestSplitIndex)
val rightChildStats = binAggregates.getParentImpurityCalculator()
rightChildStats.subtract(leftChildStats)
val bestFeatureGainStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
val rightImpurityCalculator = parentImpurityCalculator.copy.subtract(
leftImpurityCalculator)
val bestFeatureGainStats = {
if (bestGain == Double.MinValue) {
ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
else {
new ImpurityStats(bestGain, parentImpurityCalculator.calculate(),
parentImpurityCalculator, leftImpurityCalculator,
rightImpurityCalculator)
}
}
(bestSplit, bestFeatureGainStats)
}

Expand Down