Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[SPARK-3159][ML] decision tree pruning, unit tests now checking the total sum of counters at leaf level
  • Loading branch information
asolimando committed Mar 2, 2018
commit 01836782c828db8ddd928726114d464e273b0a86
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tree.impl

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.SparkFunSuite
Expand Down Expand Up @@ -640,9 +641,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// impurity measure, but always leading to the same prediction).
///////////////////////////////////////////////////////////////////////////////

test("SPARK-3159 tree model redundancy - binary classification") {
// The following dataset is set up such that splitting over feature 1 for points having
// feature 0 = 0 improves the impurity measure, despite the prediction will always be 0
test("SPARK-3159 tree model redundancy - classification") {
// The following dataset is set up such that splitting over feature_1 for points having
// feature_0 = 0 improves the impurity measure, even though the prediction will always be 0
// in both branches.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
Expand All @@ -666,11 +667,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {

assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)

assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
}

test("SPARK-3159 tree model redundancy - regression") {
// The following dataset is set up such that splitting over feature 0 for points having
// feature 1 = 1 improves the impurity measure, despite the prediction will always be 0.5
// The following dataset is set up such that splitting over feature_0 for points having
// feature_1 = 1 improves the impurity measure, even though the prediction will always be 0.5
// in both branches.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
Expand All @@ -694,6 +697,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a check in both tests to make sure that the count of all the leaf nodes sums to the total count (i.e. 6)? That way we make sure we don't lose information when merging the leaves? You can do it via leafNode.impurityStats.count.

assert(prunedTree.numNodes === 3)
assert(unprunedTree.numNodes === 5)
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
}
}

Expand All @@ -703,4 +707,16 @@ private object RandomForestSuite {
val (indices, values) = map.toSeq.sortBy(_._1).unzip
Vectors.sparse(size, indices.toArray, values.toArray)
}

@tailrec
private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to enclose the function body in curly braces

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I have added them

if (nodes.isEmpty) {
acc
}
else {
nodes.head match {
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
}
}
}