Skip to content

Commit 9dc9f9a

Browse files
zhengruifengyanboliang
authored andcommitted
[SPARK-18177][ML][PYSPARK] Add missing 'subsamplingRate' of pyspark GBTClassifier
## What changes were proposed in this pull request? Add missing 'subsamplingRate' of pyspark GBTClassifier ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #15692 from zhengruifeng/gbt_subsamplingRate.
1 parent 0ea5d5b commit 9dc9f9a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

python/pyspark/ml/classification.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -900,19 +900,19 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
900900
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
901901
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
902902
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
903-
maxIter=20, stepSize=0.1, seed=None):
903+
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0):
904904
"""
905905
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
906906
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
907907
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
908-
lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
908+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0)
909909
"""
910910
super(GBTClassifier, self).__init__()
911911
self._java_obj = self._new_java_obj(
912912
"org.apache.spark.ml.classification.GBTClassifier", self.uid)
913913
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
914914
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
915-
lossType="logistic", maxIter=20, stepSize=0.1)
915+
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0)
916916
kwargs = self.__init__._input_kwargs
917917
self.setParams(**kwargs)
918918

@@ -921,12 +921,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
921921
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
922922
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
923923
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
924-
lossType="logistic", maxIter=20, stepSize=0.1, seed=None):
924+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0):
925925
"""
926926
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
927927
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
928928
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
929-
lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
929+
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0)
930930
Sets params for Gradient Boosted Tree Classification.
931931
"""
932932
kwargs = self.setParams._input_kwargs

0 commit comments

Comments
 (0)