Exposed pruning parameter accessible in Scala WIP
CBribiescas committed Jun 14, 2021
commit fb835db4bae2dcd2c05ff4408ddc0252353ac569
@@ -75,6 +75,10 @@ class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

+  /** @group setParam */
+  @Since("3.1.2")
+  def setPruneTree(value: Boolean): this.type = set(pruneTree, value)
+
/** @group expertSetParam */
@Since("1.4.0")
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
@@ -152,10 +156,11 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
strategy.bootstrap = $(bootstrap)
+    strategy.pruneTree = $(pruneTree)

instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol,
rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins,
-      maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed,
+      maxMemoryInMB, minInfoGain, pruneTree, minInstancesPerNode, minWeightFractionPerNode, seed,
subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap)

val trees = RandomForest
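For context, a minimal sketch of how the new parameter is meant to be used from Scala. Only setPruneTree comes from this PR; the estimator, the other setters, and fit are the standard spark.ml API, and the DataFrame `training` is assumed:

    import org.apache.spark.ml.classification.RandomForestClassifier

    // Sketch only: `training` is an assumed DataFrame with "label" and "features" columns.
    val rf = new RandomForestClassifier()
      .setNumTrees(20)
      .setMaxDepth(5)
      .setPruneTree(true)  // new in this PR: merge leaves with equal class predictions

    val model = rf.fit(training)

Setting pruneTree to true is appropriate when only the class predictions matter; the merged leaves no longer carry the finer-grained class probabilities (see the review discussion further down).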
@@ -117,7 +117,6 @@ private[spark] object RandomForest extends Logging with Serializable {
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
-      prune: Boolean = false, // exposed for testing only, real trees are always not pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()
timer.start("total")
@@ -245,26 +244,26 @@
topNodes.map { rootNode =>
new DecisionTreeClassificationModel(
uid,
-            rootNode.toNode(prune),
+            rootNode.toNode(strategy.pruneTree),
numFeatures,
strategy.getNumClasses)
}
} else {
topNodes.map { rootNode =>
-          new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+          new DecisionTreeRegressionModel(uid, rootNode.toNode(strategy.pruneTree), numFeatures)
}
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
new DecisionTreeClassificationModel(
-            rootNode.toNode(prune),
+            rootNode.toNode(strategy.pruneTree),
numFeatures,
strategy.getNumClasses)
}
} else {
topNodes.map(rootNode =>
-          new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+          new DecisionTreeRegressionModel(rootNode.toNode(strategy.pruneTree), numFeatures))
}
}
}
@@ -282,7 +281,6 @@
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
-      prune: Boolean = false, // exposed for testing only, real trees are always not pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()

@@ -331,7 +329,6 @@
featureSubsetStrategy = featureSubsetStrategy,
seed = seed,
instr = instr,
-      prune = prune,
parentUID = parentUID)

baggedInput.unpersist()
@@ -973,8 +970,9 @@
val numCategories = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
-         * The bins are ordered based on centroidForCategories, and this ordering determines which
-         * splits are considered. (With K categories, we consider K - 1 possible splits.)
+         * The bins are ordered based on centroidForCategories, and this ordering determines
+         * which splits are considered. (With K categories, we
+         * consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
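The call rootNode.toNode(strategy.pruneTree) is where pruning is actually applied. As a rough illustration of the merge rule, here is a simplified, hypothetical tree ADT; Spark's internal LearningNode/Node classes differ in detail, so this is a sketch of the idea rather than the actual implementation:

    // Hypothetical, simplified tree ADT; illustrates the merge rule only.
    sealed trait SimpleNode { def prediction: Double }
    case class Leaf(prediction: Double) extends SimpleNode
    case class Internal(prediction: Double, left: SimpleNode, right: SimpleNode) extends SimpleNode

    def prune(node: SimpleNode): SimpleNode = node match {
      case Internal(pred, l, r) =>
        (prune(l), prune(r)) match {
          // Both pruned children are leaves with the same class prediction:
          // the split cannot change any prediction, so collapse it into one leaf.
          case (Leaf(lp), Leaf(rp)) if lp == rp => Leaf(lp)
          case (pl, pr) => Internal(pred, pl, pr)
        }
      case leaf => leaf
    }

Because the rule is applied bottom-up, an entire subtree whose leaves all agree collapses to a single leaf.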
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala (22 additions, 1 deletion)
@@ -75,6 +75,23 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" discretizing continuous features. Must be at least 2 and at least number of categories" +
" for any categorical feature.", ParamValidators.gtEq(2))

+  /**
+   * If true, the trained tree will undergo a 'pruning' process after training, in which
+   * nodes that have the same class prediction are merged. The benefit is that the pruned
+   * tree is smaller and faster at prediction time, so this is a good choice when only the
+   * class predictions matter.
+   * If false, the tree is left unpruned after training, which preserves the per-leaf class
+   * prediction probabilities.
+   * (default = false)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree",
+    "If true, the trained tree will undergo a 'pruning' process after training, in which" +
+    " nodes that have the same class prediction are merged, making the tree smaller and" +
+    " faster at prediction time. If false, the tree is left unpruned, preserving the" +
+    " per-leaf class prediction probabilities.")

Review comment (Member), on the scaladoc above: I think the key point for users is that this is a good thing if only interested in class predictions. If interested in class probabilities, it won't necessarily give the right result and should be set to false. The text here is fine, just wanting to make the tradeoffs explicit. I.e. leaner means smaller and faster.

Review comment (Member), on the default value: I think as a first step we should keep it to 'true'.

/**
* Minimum number of instances each child must have after split.
* If a split causes the left or right child to have fewer than minInstancesPerNode,
@@ -137,7 +154,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" trees.")

setDefault(leafCol -> "", maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
-    minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+    minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, pruneTree -> false, maxMemoryInMB -> 256,
cacheNodeIds -> false, checkpointInterval -> 10)

/** @group setParam */
@@ -163,6 +180,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain)

+  /** @group getParam */
+  final def getPruneTree: Boolean = $(pruneTree)
+
/** @group expertGetParam */
final def getMaxMemoryInMB: Int = $(maxMemoryInMB)

@@ -183,6 +203,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxDepth = getMaxDepth
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
+    strategy.pruneTree = getPruneTree
strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds
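To make the reviewer's tradeoff concrete, an illustrative calculation (the numbers are made up, and it assumes the merged leaf reports its parent's pooled class counts):

    // Two sibling leaves both predict class 1, so pruning merges them:
    val leafA = (90.0, 100.0)   // 90 of 100 training samples are class 1 -> P(class 1) = 0.90
    val leafB = (60.0, 100.0)   // 60 of 100 training samples are class 1 -> P(class 1) = 0.60
    // The merged leaf pools the counts:
    val pooledP = (leafA._1 + leafB._1) / (leafA._2 + leafB._2)  // 150.0 / 200.0 = 0.75
    // Instances formerly routed to A or B now all report P(class 1) = 0.75:
    // the argmax prediction (class 1) is unchanged, but the probabilities are not.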
@@ -55,6 +55,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
* If a split has less information gain than minInfoGain,
* this split will not be considered as a valid split.
+ * @param pruneTree If true, prune the tree after training by merging leaf nodes that have
+ *                  the same class prediction. Default value is false.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB. If too small, then 1 node will be split per iteration, and
* its aggregates may exceed this size.
@@ -77,6 +78,7 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
@Since("3.1.2") @BeanProperty var pruneTree: Boolean = false,
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@@ -113,12 +115,13 @@
categoricalFeaturesInfo: Map[Int, Int],
minInstancesPerNode: Int,
minInfoGain: Double,
+      pruneTree: Boolean,
maxMemoryInMB: Int,
subsamplingRate: Double,
useNodeIdCache: Boolean,
checkpointInterval: Int) = {
this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
-      categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+      categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB,
subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
}
// scalastyle:on argcount
@@ -200,7 +203,7 @@
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
-      minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
+      minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache,
checkpointInterval, minWeightFractionPerNode)
}
}
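A corresponding sketch for the RDD-based mllib API: because pruneTree is declared as a @BeanProperty var with a default, existing callers compile unchanged and can opt in by plain assignment. The import and constructor below are the standard mllib API; only the pruneTree field comes from this PR:

    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5, numClasses = 2)
    strategy.pruneTree = true  // new field from this PR; defaults to false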