Skip to content

Commit 0496413

Browse files
zhengruifengsrowen
authored andcommitted
[SPARK-30354][ML] GBT reuse DecisionTreeMetadata among iterations
### What changes were proposed in this pull request? precompute the `DecisionTreeMetadata` and reuse it for all trees ### Why are the changes needed? In existing impl, each `DecisionTreeRegressor` needs a pass on the whole dataset to calculate the same `DecisionTreeMetadata` repeatedly. In this PR, with default depth=5, it is about 8% faster then existing impl ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing testsuites Closes apache#27011 from zhengruifeng/gbt_reuse_instr_meta. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
1 parent 16e5e79 commit 0496413

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree.impl
2020
import org.apache.spark.internal.Logging
2121
import org.apache.spark.ml.feature.Instance
2222
import org.apache.spark.ml.linalg.Vector
23-
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
23+
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
2424
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
2525
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
2626
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
@@ -306,8 +306,12 @@ private[spark] object GradientBoostedTrees extends Logging {
306306

307307
// Initialize tree
308308
timer.start("building tree 0")
309-
val firstTree = new DecisionTreeRegressor().setSeed(seed)
310-
val firstTreeModel = firstTree.train(input, treeStrategy, featureSubsetStrategy)
309+
val metadata = RandomForest.buildMetadata(input, treeStrategy,
310+
numTrees = 1, featureSubsetStrategy)
311+
val firstTreeModel = RandomForest.run(input, treeStrategy, numTrees = 1,
312+
featureSubsetStrategy, seed = seed, instr = None,
313+
parentUID = None, precomputedMetadata = Some(metadata))
314+
.head.asInstanceOf[DecisionTreeRegressionModel]
311315
val firstTreeWeight = 1.0
312316
baseLearners(0) = firstTreeModel
313317
baseLearnerWeights(0) = firstTreeWeight
@@ -342,8 +346,10 @@ private[spark] object GradientBoostedTrees extends Logging {
342346
logDebug("Gradient boosting tree iteration " + m)
343347
logDebug("###################################################")
344348

345-
val dt = new DecisionTreeRegressor().setSeed(seed + m)
346-
val model = dt.train(data, treeStrategy, featureSubsetStrategy)
349+
val model = RandomForest.run(data, treeStrategy, numTrees = 1,
350+
featureSubsetStrategy, seed = seed + m, instr = None,
351+
parentUID = None, precomputedMetadata = Some(metadata))
352+
.head.asInstanceOf[DecisionTreeRegressionModel]
347353
timer.stop(s"building tree $m")
348354
// Update partial model
349355
baseLearners(m) = model

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ private[spark] object RandomForest extends Logging with Serializable {
9999
run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
100100
}
101101

102+
def buildMetadata(
103+
input: RDD[Instance],
104+
strategy: OldStrategy,
105+
numTrees: Int,
106+
featureSubsetStrategy: String): DecisionTreeMetadata = {
107+
val retaggedInput = input.retag(classOf[Instance])
108+
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
109+
}
110+
102111
/**
103112
* Train a random forest.
104113
*
@@ -113,7 +122,8 @@ private[spark] object RandomForest extends Logging with Serializable {
113122
seed: Long,
114123
instr: Option[Instrumentation],
115124
prune: Boolean = true, // exposed for testing only, real trees are always pruned
116-
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
125+
parentUID: Option[String] = None,
126+
precomputedMetadata: Option[DecisionTreeMetadata] = None): Array[DecisionTreeModel] = {
117127

118128
val timer = new TimeTracker()
119129

@@ -122,8 +132,9 @@ private[spark] object RandomForest extends Logging with Serializable {
122132
timer.start("init")
123133

124134
val retaggedInput = input.retag(classOf[Instance])
125-
val metadata =
135+
val metadata = precomputedMetadata.getOrElse {
126136
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
137+
}
127138

128139
instr match {
129140
case Some(instrumentation) =>

0 commit comments

Comments
 (0)