diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 09f81b0dcbda..74624be360c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -203,6 +203,7 @@ class GBTClassifier @Since("1.4.0") ( } else { GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + baseLearners.foreach(copyValues(_)) val numFeatures = baseLearners.head.numFeatures instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 731b43b67813..245cda35d8ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") ( val trees = RandomForest .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) + trees.foreach(copyValues(_)) val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 00c0bc9f5e28..0cc06d82bf3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -181,6 +181,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + baseLearners.foreach(copyValues(_)) val numFeatures = baseLearners.head.numFeatures instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 938aa5acac08..8f78fc1da18c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val trees = RandomForest .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) + trees.foreach(copyValues(_)) val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index af3dd201d3b5..530ca20d0eb0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -456,6 +456,22 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { } } + test("tree params") { + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setCheckpointInterval(5) + .setSeed(123) + val model = gbt.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getCheckpointInterval === model.getCheckpointInterval) + assert(i.getSeed === model.getSeed) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index f03ed0b76eb8..5958bfcf5ea6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -230,6 +230,26 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { } } + test("tree params") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("entropy") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getSeed === model.getSeed) + assert(i.getImpurity === model.getImpurity) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 60007975c3b5..e2462af2ac1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -296,7 +296,21 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("tree params") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setCheckpointInterval(5) + .setSeed(123) + val model = gbt.fit(trainData.toDF) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getCheckpointInterval === model.getCheckpointInterval) + assert(i.getSeed === model.getSeed) + }) + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 0243e8d2335e..f3b0f0470e57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -139,6 +139,25 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ } } + test("tree params") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(3) + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + + model.trees.foreach (i => { + assert(i.getMaxDepth === model.getMaxDepth) + assert(i.getSeed === model.getSeed) + assert(i.getImpurity === model.getImpurity) + assert(i.getMaxBins === model.getMaxBins) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load /////////////////////////////////////////////////////////////////////////////