[SPARK-3159][ML] Add decision tree pruning #20632
Changes from 2 commits
RandomForestSuite.scala

```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.ml.tree.impl
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
   import RandomForestSuite.mapToVec
 
+  private val seed = 42
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests for split calculation
   /////////////////////////////////////////////////////////////////////////////
@@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
     val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42, instr = None, prune = false).head
 
     model.rootNode match {
       case n: InternalNode => n.split match {
         case s: CategoricalSplit =>
@@ -631,13 +635,88 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
 
+  ///////////////////////////////////////////////////////////////////////////////
+  // Tests for pruning of redundant subtrees (generated by a split improving the
+  // impurity measure, but always leading to the same prediction).
+  ///////////////////////////////////////////////////////////////////////////////
+
+  test("SPARK-3159 tree model redundancy - classification") {
+    // The following dataset is set up such that splitting on feature_1 for points with
+    // feature_0 = 0 improves the impurity measure, even though the prediction is always 0
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val numClasses = 2
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 5)
+    assert(unprunedTree.numNodes === 7)
+
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
+
+  test("SPARK-3159 tree model redundancy - regression") {
+    // The following dataset is set up such that splitting on feature_0 for points with
+    // feature_1 = 1 improves the impurity measure, even though the prediction is always 0.5
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
+      numClasses = 0, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
```
Contributor: Would you mind adding a check in both tests to make sure that the count of all the leaf nodes sums to the total count (i.e. 6)? That way we make sure we don't lose information when merging the leaves. You can do it via …
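The suggestion above is truncated, but its intent matches the `getSumLeafCounters` helper that lands later in this diff. For illustration only, a minimal sketch of such a check, assuming the `Node`/`InternalNode`/`LeafNode` classes and the `impurityStats.count` field used by this suite (note that `impurityStats` is only visible from within Spark's `ml` packages, as in this test):

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Sketch: recursively sum the training-instance counts stored in the leaves.
// If pruning merges leaves correctly, no instances are lost, so the sum must
// equal the size of the training set.
def sumLeafCounts(node: Node): Long = node match {
  case i: InternalNode => sumLeafCounts(i.leftChild) + sumLeafCounts(i.rightChild)
  case l: LeafNode => l.impurityStats.count
}

// Hypothetical usage inside a test:
// assert(sumLeafCounts(prunedTree.rootNode) === arr.size)
```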
```diff
+    assert(prunedTree.numNodes === 3)
+    assert(unprunedTree.numNodes === 5)
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
 }
 
 private object RandomForestSuite {
 
   def mapToVec(map: Map[Int, Double]): Vector = {
     val size = (map.keys.toSeq :+ 0).max + 1
     val (indices, values) = map.toSeq.sortBy(_._1).unzip
     Vectors.sparse(size, indices.toArray, values.toArray)
   }
+
+  @tailrec
+  private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long =
+    if (nodes.isEmpty) {
+      acc
+    } else {
+      nodes.head match {
+        case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+        case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+      }
+    }
 }
```
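For readers skimming the thread: the redundancy these tests target is an internal node whose split improves the impurity measure but whose two branches end up predicting the same value. A minimal sketch of the merge on a toy tree type (illustrative only; the PR itself prunes `LearningNode`s while the tree is built, and the class names below are not Spark's):

```scala
// Toy tree, purely illustrative.
sealed trait Tree
case class Leaf(prediction: Double, count: Long) extends Tree
case class Branch(left: Tree, right: Tree) extends Tree

// Bottom-up pruning: if both pruned children are leaves with the same
// prediction, the split is redundant and the two leaves are merged. Counts
// are accumulated so the leaf-counter invariant checked above still holds.
def prune(tree: Tree): Tree = tree match {
  case Branch(l, r) =>
    (prune(l), prune(r)) match {
      case (Leaf(p1, c1), Leaf(p2, c2)) if p1 == p2 => Leaf(p1, c1 + c2)
      case (pl, pr) => Branch(pl, pr)
    }
  case leaf => leaf
}
```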
DecisionTreeSuite.scala

```diff
@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     // if a split does not satisfy min instances per node requirements,
     // this split is invalid, even though the information gain of split is large.
     val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))
 
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
       Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
-      if (i < 1000) {
+      if (i < 1001) {
         arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       } else if (i < 2000) {
         arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
```
Review comment: If you just overload the method then you don't need to change the existing function calls.
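A hedged sketch of that suggestion with a deliberately simplified signature (the real `RandomForest.run` takes several more parameters; only the overloading pattern is the point):

```scala
object OverloadSketch {
  // The existing signature is kept, so current call sites compile unchanged;
  // it simply delegates to the new overload with the default behavior.
  def run(numTrees: Int, seed: Long): Unit =
    run(numTrees, seed, prune = true)

  // The new overload carries the extra parameter.
  def run(numTrees: Int, seed: Long, prune: Boolean): Unit =
    println(s"training $numTrees trees (seed=$seed, prune=$prune)")
}

// Callers written before the change still work:
// OverloadSketch.run(numTrees = 1, seed = 42)
```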