Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fab2a0e
TST: create new test suite
facaiy Mar 31, 2017
f5d52cc
TST: helper method for construcing binary tree
facaiy Mar 31, 2017
b9248b7
TST: helper method, show tree node info
facaiy Mar 31, 2017
be12f4f
TST: helper method, check if pairs of leave with same prediction exists
facaiy Mar 31, 2017
b524202
TST: helper method for modifying nodes
facaiy Mar 31, 2017
98a73f9
ENH: merge the pairs of leave with same prediction of same parent
facaiy Mar 31, 2017
632325d
ENH: add mergeLeave param in Strategy
facaiy Apr 1, 2017
1205295
ENH: support mergeChild when training
facaiy Apr 1, 2017
434c762
ENH: add canMergeChildren param in DecisionTreeParams
facaiy Apr 1, 2017
5162552
ENH: add set method in tree classifier
facaiy Apr 1, 2017
21b1a85
ENH: stat: merge counts of each tree
facaiy Apr 1, 2017
25b712a
BUG: depth=0 tree has none of children
facaiy Apr 1, 2017
749dbd8
TST: add comment for test suite
facaiy Apr 1, 2017
fbd1c9a
TST: add iris dataset
facaiy Apr 5, 2017
f81e4e3
TST: helper method, check children of Node
facaiy Apr 5, 2017
129e6fe
TST: add unit test, check setCanMergeChildren for DecisionTreeClassifier
facaiy Apr 5, 2017
3f49146
TST: rename test case
facaiy Apr 5, 2017
1b42afb
BUG: fix for terminal node whose isLeaf is false
facaiy Apr 5, 2017
93ffa3f
CLN: format mergeCountsOfTreesInfo
facaiy Apr 5, 2017
a8351f8
CLN: fix for code style
facaiy Apr 24, 2017
701806f
BUG: mima, fix binary compatibility
facaiy Jul 5, 2017
a472c3a
merge latest master and reslove conflict
facaiy Jul 5, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/** @group expertSetParam */
@Since("2.2.0")
def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value)

/**
* Specifies how often to checkpoint the cached node IDs.
* E.g. 10 means that the cache will get checkpointed every 10 iterations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/** @group expertSetParam */
@Since("2.2.0")
def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value)

/**
* Specifies how often to checkpoint the cached node IDs.
* E.g. 10 means that the cache will get checkpointed every 10 iterations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/** @group expertSetParam */
@Since("2.2.0")
def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value)

/**
* Specifies how often to checkpoint the cached node IDs.
* E.g. 10 means that the cache will get checkpointed every 10 iterations.
Expand Down
47 changes: 47 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,53 @@ private[tree] object LearningNode {
new LearningNode(nodeIndex, None, None, None, false, null)
}

/** merge the pair of leave of same parent recursively if their prediction is same. */
def mergeChildrenWithSamePrediction(node: LearningNode): Int = {

/** find the mergeable leave and merge them. */
def checkAndMerge(node: LearningNode): (Option[Double], Int) = {
if (node == null) {
(None, 0)

} else if (node.isLeaf ||
// sometimes, node is terminal while its isLeaf is not set.
(node.leftChild.isEmpty && node.rightChild.isEmpty)) {
(Some(node.stats.impurityCalculator.predict), 0)

} else {
val (leftNode, leftMergeCounts) = checkAndMerge(node.leftChild.orNull)
val (rightNode, rightMergeCounts) = checkAndMerge(node.rightChild.orNull)
val mergeCounts = leftMergeCounts + rightMergeCounts

if (leftNode.isDefined && rightNode.isDefined && leftNode == rightNode) {
removeChildren(node)
node.isLeaf = true

(Some(node.stats.impurityCalculator.predict), mergeCounts + 1)

} else {
(None, mergeCounts)
}
}
}

val (_, mergeCounts) = checkAndMerge(node)

mergeCounts
}

/** delete all children of one node. */
def removeChildren(learningNode: LearningNode): Unit = {
val left = learningNode.leftChild
val right = learningNode.rightChild

learningNode.leftChild = None
learningNode.rightChild = None

left.foreach(removeChildren)
right.foreach(removeChildren)
}

// The below indexing methods were copied from spark.mllib.tree.model.Node

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,23 @@ private[spark] object RandomForest extends Logging {
}
}

// prune tree.
if (strategy.canMergeChildren) {
logInfo("Merge children with same prediction:")

// merge the pair of leaves if possible.
val mergeCounts = topNodes.map(LearningNode.mergeChildrenWithSamePrediction)

val mergeCountsOfTreesInfo: String = mergeCounts
.zipWithIndex
.map(_.swap)
.filter(_._2 > 0)
.map { case (id, count) => s"tree: $id, num of nodes merged: $count" }
.mkString("Merge info:\n", "\n", "\n")

logInfo(mergeCountsOfTreesInfo)
}

val numFeatures = metadata.numFeatures

parentUID match {
Expand Down
16 changes: 15 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,18 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.")

/**
* If true, tree will try to merge its leaf nodes which share the same parent and output
* the same prediction.
* (default = false)
* @group expertParam
*/
final val canMergeChildren: BooleanParam = new BooleanParam(this, "canMergeChildren", "If true," +
" the tree will try to merge its leaf nodes which share the same parent and output.")

setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
maxMemoryInMB -> 256, cacheNodeIds -> false, canMergeChildren -> false,
checkpointInterval -> 10)

/**
* @deprecated This method is deprecated and will be removed in 3.0.0.
Expand Down Expand Up @@ -176,6 +186,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group expertGetParam */
final def getCacheNodeIds: Boolean = $(cacheNodeIds)

/** @group expertGetParam */
final def getCanMergeChildren: Boolean = $(canMergeChildren)

/**
* @deprecated This method is deprecated and will be removed in 3.0.0.
* @group setParam
Expand All @@ -199,6 +212,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.canMergeChildren = getCanMergeChildren
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
* @param canMergeChildren Merge pairs of leaf nodes of the same parent which
Copy link
Contributor Author

@facaiy facaiy Apr 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new parameter is added in Strategy class, which fails Mima tests. How to deal with it?

java.lang.RuntimeException: spark-mllib: Binary compatibility check failed!

[error]  * synthetic method <init>$default$13()Int in object org.apache.spark.mllib.tree.configuration.Strategy has a different result type in current version, where it is Boolean rather than Int

see failed logs

* output the same prediction.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
* the checkpoint directory is not set in
Expand All @@ -80,6 +82,7 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("2.2.0") @BeanProperty var canMergeChildren: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {

/**
Expand Down Expand Up @@ -172,7 +175,7 @@ class Strategy @Since("1.3.0") (
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
maxMemoryInMB, subsamplingRate, useNodeIdCache, canMergeChildren, checkpointInterval)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode, Node}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
Expand Down Expand Up @@ -350,6 +350,31 @@ class DecisionTreeClassifierSuite
dt.fit(df)
}

test("SPARK-3159: Check for reducible DecisionTree") {
import DecisionTreeClassifierSuite.hasPairsOfSameChildren

val df: DataFrame = TreeTests.getIrisDataset(sc).toDF()
val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
.setMaxBins(2)
.setMinInfoGain(0)
.setMinInstancesPerNode(1)
.setMaxDepth(3)
.setSeed(0)

val m1 = dt.setCanMergeChildren(false).fit(df)
assert(true === hasPairsOfSameChildren(m1.rootNode))

val m2 = dt.setCanMergeChildren(true).fit(df)
assert(false === hasPairsOfSameChildren(m2.rootNode))

val p1 = m1.transform(df).select(m1.getPredictionCol).collect()
val p2 = m2.transform(df).select(m2.getPredictionCol).collect()
assert(p1.zip(p2).forall { case (y1, y2) =>
y1.getDouble(0) === y2.getDouble(0)
})
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -422,4 +447,25 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
TreeTests.checkEqual(oldTreeAsNew, newTree)
assert(newTree.numFeatures === numFeatures)
}

/** check if there exists pairs of leaf nodes with same prediction of the same parent. */
def hasPairsOfSameChildren(node: Node): Boolean = {
def check(node: Node): (Boolean, Option[Double]) = node match {
case n: LeafNode => (false, Some(n.prediction))
case n: InternalNode =>
val (leftFound, leftPredict) = check(n.leftChild)
val (rightFound, rightPredict) = check(n.rightChild)

if (leftFound || rightFound ||
(leftPredict.isDefined && leftPredict == rightPredict)) {
(true, None)
} else {
(false, None)
}
}

val (found, _) = check(node)

found
}
}
Loading