22 changes: 14 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -19,8 +19,7 @@ package org.apache.spark.ml.tree

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
- import org.apache.spark.mllib.tree.model.{ImpurityStats,
-   InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+ import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}

/**
* Decision tree node interface.
@@ -266,15 +265,23 @@ private[tree] class LearningNode(
var isLeaf: Boolean,
var stats: ImpurityStats) extends Serializable {

+ def toNode: Node = toNode(prune = true)

/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
- def toNode: Node = {
- if (leftChild.nonEmpty) {
- assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+ def toNode(prune: Boolean = true): Node = {
Contributor: If you just overload the method then you don't need to change the existing function calls.

    def toNode: Node = toNode(prune = true)
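For context, this is general Scala behavior rather than anything specific to this PR: a default argument alone would not keep parameterless call sites compiling, because a method declared as `def toNode(prune: Boolean = true)` must still be invoked with an argument list, as `toNode()`. The zero-arg overload is what lets existing `rootNode.toNode` call sites compile unchanged. A minimal, self-contained sketch with illustrative names:

    object OverloadSketch {
      class Example {
        // With only this method, `e.render` (no parens) does not compile;
        // the default is applied only when an argument list is given: `e.render()`.
        def render(flag: Boolean = true): String = s"flag=$flag"

        // Parameterless overload keeps old call sites `e.render` working unchanged.
        def render: String = render(flag = true)
      }

      def main(args: Array[String]): Unit = {
        val e = new Example
        println(e.render)               // parameterless overload, prints flag=true
        println(e.render(flag = false)) // explicit argument, prints flag=false
      }
    }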


+ if (!leftChild.isEmpty || !rightChild.isEmpty) {
+ assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
- new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
- leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+ (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+ case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+ new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+ case (l, r) =>
+ new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+ l, r, split.get, stats.impurityCalculator)
+ }
} else {
if (stats.valid) {
new LeafNode(stats.impurityCalculator.predict, stats.impurity,
@@ -283,7 +290,6 @@
// Here we want to keep same behavior with the old mllib.DecisionTreeModel
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
}

}
}
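To make the new pruning rule easier to see in isolation, here is a minimal sketch of the same bottom-up idea on a simplified tree type (illustrative only; Spark's real `Node`/`LeafNode`/`InternalNode` classes also carry impurity statistics, as above):

    sealed trait SimpleNode
    final case class Leaf(prediction: Double) extends SimpleNode
    final case class Internal(left: SimpleNode, right: SimpleNode) extends SimpleNode

    // Prune children first; if both collapse to leaves with the same prediction,
    // the split is redundant and the whole subtree becomes a single leaf.
    def prune(node: SimpleNode): SimpleNode = node match {
      case Internal(left, right) =>
        (prune(left), prune(right)) match {
          case (l: Leaf, r: Leaf) if l.prediction == r.prediction => Leaf(l.prediction)
          case (l, r) => Internal(l, r)
        }
      case leaf => leaf
    }

    // Example: the inner split below always predicts 0.0, so it collapses:
    // prune(Internal(Internal(Leaf(0.0), Leaf(0.0)), Leaf(1.0)))
    //   == Internal(Leaf(0.0), Leaf(1.0))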

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation[_]],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {

val timer = new TimeTracker()
@@ -223,22 +224,23 @@
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+ new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
strategy.getNumClasses)
}
} else {
topNodes.map { rootNode =>
- new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+ new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
}
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+ new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
strategy.getNumClasses)
}
} else {
- topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
}
}
}
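For illustration, here is how the new flag changes a call site; this is a sketch that mirrors the tests added later in this PR, with `input` and `strategy` assumed to be in scope:

    // prune defaults to true, so production callers are unaffected.
    val pruned = RandomForest.run(input, strategy, numTrees = 1,
      featureSubsetStrategy = "auto", seed = 42, instr = None).head

    // Tests can opt out to inspect the raw, unpruned tree.
    val unpruned = RandomForest.run(input, strategy, numTrees = 1,
      featureSubsetStrategy = "auto", seed = 42, instr = None, prune = false).head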
mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite
dt.fit(df)
}

test("Use soft prediction for binary classification with ordered categorical features") {
// The following dataset is set up such that the best split is {1} vs. {0, 2}.
// If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(1.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
val data = sc.parallelize(arr)
val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)

// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val dt = new DecisionTreeClassifier()
.setImpurity("gini")
.setMaxDepth(1)
.setMaxBins(3)
val model = dt.fit(df)
model.rootNode match {
case n: InternalNode =>
n.split match {
case s: CategoricalSplit =>
assert(s.leftCategories === Array(1.0))
case other =>
fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.")
}
case other =>
fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.")
}
}

test("Feature importance with toy data") {
val dt = new DecisionTreeClassifier()
.setImpurity("gini")
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.ml.tree.impl

+ import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.SparkFunSuite
@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {

import RandomForestSuite.mapToVec

+ private val seed = 42

/////////////////////////////////////////////////////////////////////////////
// Tests for split calculation
/////////////////////////////////////////////////////////////////////////////
@@ -320,10 +323,10 @@
assert(topNode.isLeaf === false)
assert(topNode.stats === null)

- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
+ val nodesForGroup = Map(0 -> Array(topNode))
+ val treeToNodeToIndexInfo = Map(0 -> Map(
+ topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+ ))
val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
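As an aside on the `Map` literal style change above (general Scala, not specific to this PR): `k -> v` is only sugar for the tuple `(k, v)`, so the rewritten literals build exactly the same maps. A quick self-contained check:

    object ArrowSyntaxCheck {
      def main(args: Array[String]): Unit = {
        val tupleStyle = Map((0, Map((1, "a"))))
        val arrowStyle = Map(0 -> Map(1 -> "a"))
        assert(tupleStyle == arrowStyle) // -> constructs a Tuple2 under the hood
        println("equal: " + (tupleStyle == arrowStyle))
      }
    }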
@@ -362,10 +365,10 @@
assert(topNode.isLeaf === false)
assert(topNode.stats === null)

- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
+ val nodesForGroup = Map(0 -> Array(topNode))
Contributor: These are fine, but I slightly prefer leaving stuff like this out. These aren't strictly style violations, and it distracts reviewers from the actual changes.

Member: That was at my request: we made that change in some new code, and then it made sense to improve this here as well, for consistency.

Member Author: Two tests previously moved here have now been moved back. There is still "Use soft prediction for binary classification with ordered categorical features", to which I have applied @srowen's comment, so the consistency argument still holds (even if weakened a bit).

+ val treeToNodeToIndexInfo = Map(0 -> Map(
+ topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+ ))
val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -407,7 +410,8 @@
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)

val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 42, instr = None).head
+ seed = 42, instr = None, prune = false).head

model.rootNode match {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
@@ -631,13 +635,88 @@
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}

+ ///////////////////////////////////////////////////////////////////////////////
+ // Tests for pruning of redundant subtrees (generated by a split improving the
+ // impurity measure, but always leading to the same prediction).
+ ///////////////////////////////////////////////////////////////////////////////
+
+ test("SPARK-3159 tree model redundancy - classification") {
+ // The following dataset is set up such that splitting over feature_1 for points having
+ // feature_0 = 0 improves the impurity measure, even though the prediction will always
+ // be 0 in both branches.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ )
+ val rdd = sc.parallelize(arr)
+
+ val numClasses = 2
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None).head
+
+ val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None, prune = false).head
+
+ assert(prunedTree.numNodes === 5)
+ assert(unprunedTree.numNodes === 7)
+
+ assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+ }
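To see concretely why this split is redundant, here is the impurity arithmetic for the feature_0 = 0 branch. This worked example is added for illustration and is not part of the PR; it assumes ties in the class counts resolve to the first class:

    // Gini impurity from per-class counts.
    def gini(counts: Seq[Double]): Double = {
      val n = counts.sum
      1.0 - counts.map(c => (c / n) * (c / n)).sum
    }

    // The three points with feature_0 = 0 have labels (0, 1, 0).
    val parent = gini(Seq(2, 1)) // 1 - (4/9 + 1/9) ≈ 0.444

    // Splitting them on feature_1: one side has labels (0, 1), the other (0).
    val children = (2.0 / 3) * gini(Seq(1, 1)) + (1.0 / 3) * gini(Seq(1)) // ≈ 0.333

    // The gain is positive (≈ 0.111), so the unpruned tree keeps the split,
    // but both children predict class 0, so pruning collapses it into one leaf.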

test("SPARK-3159 tree model redundancy - regression") {
// The following dataset is set up such that splitting over feature_0 for points having
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
// in both branches.
val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
)
val rdd = sc.parallelize(arr)

val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
numClasses = 0, maxBins = 32)

val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
seed = 42, instr = None).head

val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
seed = 42, instr = None, prune = false).head

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a check in both tests to make sure that the count of all the leaf nodes sums to the total count (i.e. 6)? That way we make sure we don't lose information when merging the leaves? You can do it via leafNode.impurityStats.count.

+ assert(prunedTree.numNodes === 3)
+ assert(unprunedTree.numNodes === 5)
+ assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+ }
}

private object RandomForestSuite {

def mapToVec(map: Map[Int, Double]): Vector = {
val size = (map.keys.toSeq :+ 0).max + 1
val (indices, values) = map.toSeq.sortBy(_._1).unzip
Vectors.sparse(size, indices.toArray, values.toArray)
}

+ @tailrec
+ private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
Contributor: Need to enclose the function body in curly braces.

Member Author: Sorry, I have added them.

+ if (nodes.isEmpty) {
+ acc
+ } else {
+ nodes.head match {
+ case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+ case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+ }
+ }
+ }
}
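A side note on the shape of this helper (general Scala, not part of the PR): a direct recursion like `sum(left) + sum(right)` is not a tail call, so it could not be `@tailrec`; threading an explicit worklist plus an accumulator makes every recursive call the last action, which the compiler can turn into a loop. A self-contained sketch on a toy tree:

    import scala.annotation.tailrec

    object TailrecSketch {
      sealed trait Tree
      final case class TLeaf(count: Long) extends Tree
      final case class TBranch(left: Tree, right: Tree) extends Tree

      // Pending subtrees live on the worklist instead of the JVM call stack.
      @tailrec
      def sumLeafCounts(pending: List[Tree], acc: Long = 0): Long = pending match {
        case Nil => acc
        case TBranch(l, r) :: rest => sumLeafCounts(l :: r :: rest, acc)
        case TLeaf(c) :: rest => sumLeafCounts(rest, acc + c)
      }

      def main(args: Array[String]): Unit = {
        // A branch over (2) and a nested branch over (1, 3): total 6.
        println(sumLeafCounts(List(TBranch(TLeaf(2), TBranch(TLeaf(1), TLeaf(3))))))
      }
    }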
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
// if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
Contributor: Here again. You can fix this by inverting the labels. Probably an easier fix than moving and re-writing the test.

Member Author: That's true; I have modified the input data for both tests as suggested, and "moved back" the two tests from .../ml/tree/impl/RandomForestSuite.scala to .../mllib/tree/DecisionTreeSuite.scala, where they originally were. The whole suite of tests for mllib passes.

As a recap, two tests have been adapted by slightly changing the input data:

  • "Multiclass classification stump with 10-ary (ordered) categorical features"
  • "do not choose split that does not satisfy min instance per node requirements"

"Use soft prediction for binary classification with ordered categorical features" was present in two files:

  1. .../ml/classification/DecisionTreeClassifierSuite.scala
  2. .../ml/tree/impl/RandomForestSuite.scala

The one in 1. has been removed because it had to be adapted and was redundant, while the one in 2. has been adapted following the same principle as other tests in that file, such as the "Avoid aggregation on the last level" test.

- LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))

val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](3000)
for (i <- 0 until 3000) {
- if (i < 1000) {
+ if (i < 1001) {
Contributor: This is the type of thing that will puzzle someone down the line. I'm ok with it, though. 😝

Member Author: I agree; I have added a comment to explain the change:

    [SPARK-3159] 1001 instead of 1000 to adapt "Multiclass classification stump with 10-ary (ordered) categorical features" test (different predictions prevent subtree pruning)

Could this be helpful?

arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
} else if (i < 2000) {
arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))