Skip to content

Commit 51ef781

Browse files
committed
Fixed bug introduced by last commit: Variance impurity calculation was incorrect since counts were swapped accidentally
1 parent fd65372 commit 51ef781

File tree

1 file changed

+6
-6
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/tree/impurity

1 file changed

+6
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,21 @@ private[tree] class VarianceAggregator extends ImpurityAggregator(3) with Serial
7575
}
7676

7777
def add(label: Double): Unit = {
78-
counts(0) += label
79-
counts(1) += label * label
80-
counts(2) += 1
78+
counts(0) += 1
79+
counts(1) += label
80+
counts(2) += label * label
8181
}
8282

83-
def count: Long = counts(2).toLong
83+
def count: Long = counts(0).toLong
8484

8585
def predict: Double = if (count == 0) {
8686
0
8787
} else {
88-
counts(0) / counts(2)
88+
counts(1) / count
8989
}
9090

9191
override def toString: String = {
92-
s"VarianceAggregator(sum = ${counts(0)}, sum2 = ${counts(1)}, cnt = ${counts(2)})"
92+
s"VarianceAggregator(cnt = ${counts(0)}, sum = ${counts(1)}, sum2 = ${counts(2)})"
9393
}
9494

9595
def newAggregator: VarianceAggregator = {

0 commit comments

Comments
 (0)