Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class GBTClassifier @Since("1.4.0") (
} else {
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
}
baseLearners.foreach(copyValues(_))

val numFeatures = baseLearners.head.numFeatures
instr.logNumFeatures(numFeatures)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") (
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
trees.foreach(copyValues(_))

val numFeatures = trees.head.numFeatures
instr.logNumClasses(numClasses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
GradientBoostedTrees.run(trainDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
}
baseLearners.foreach(copyValues(_))

val numFeatures = baseLearners.head.numFeatures
instr.logNumFeatures(numFeatures)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
trees.foreach(copyValues(_))

val numFeatures = trees.head.numFeatures
instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,22 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
}
}

test("tree params") {
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val gbt = new GBTClassifier()
.setMaxDepth(2)
.setCheckpointInterval(5)
.setSeed(123)
val model = gbt.fit(df)

model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
assert(i.getCheckpointInterval === model.getCheckpointInterval)
assert(i.getSeed === model.getSeed)
})
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,26 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
}
}

test("tree params") {
val rdd = orderedLabeledPoints5_20
val rf = new RandomForestClassifier()
.setImpurity("entropy")
.setMaxDepth(3)
.setNumTrees(3)
.setSeed(123)
val categoricalFeatures = Map.empty[Int, Int]
val numClasses = 2

val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)

model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
assert(i.getSeed === model.getSeed)
assert(i.getImpurity === model.getImpurity)
})
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,21 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}

/////////////////////////////////////////////////////////////////////////////
test("tree params") {
val gbt = new GBTRegressor()
.setMaxDepth(2)
.setCheckpointInterval(5)
.setSeed(123)
val model = gbt.fit(trainData.toDF)

model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
assert(i.getCheckpointInterval === model.getCheckpointInterval)
assert(i.getSeed === model.getSeed)
})
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,25 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
}
}

test("tree params") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
.setMaxDepth(2)
.setMaxBins(10)
.setNumTrees(3)
.setSeed(123)

val df = orderedLabeledPoints50_1000.toDF()
val model = rf.fit(df)

model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
assert(i.getSeed === model.getSeed)
assert(i.getImpurity === model.getImpurity)
assert(i.getMaxBins === model.getMaxBins)
})
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
Expand Down