@@ -900,19 +900,19 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
900900 def __init__ (self , featuresCol = "features" , labelCol = "label" , predictionCol = "prediction" ,
901901 maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
902902 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 , lossType = "logistic" ,
903- maxIter = 20 , stepSize = 0.1 , seed = None ):
903+ maxIter = 20 , stepSize = 0.1 , seed = None , subsamplingRate = 1.0 ):
904904 """
905905 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
906906 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
907907 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
908- lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
908+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0 )
909909 """
910910 super (GBTClassifier , self ).__init__ ()
911911 self ._java_obj = self ._new_java_obj (
912912 "org.apache.spark.ml.classification.GBTClassifier" , self .uid )
913913 self ._setDefault (maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
914914 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 ,
915- lossType = "logistic" , maxIter = 20 , stepSize = 0.1 )
915+ lossType = "logistic" , maxIter = 20 , stepSize = 0.1 , subsamplingRate = 1.0 )
916916 kwargs = self .__init__ ._input_kwargs
917917 self .setParams (** kwargs )
918918
@@ -921,12 +921,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
921921 def setParams (self , featuresCol = "features" , labelCol = "label" , predictionCol = "prediction" ,
922922 maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
923923 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 ,
924- lossType = "logistic" , maxIter = 20 , stepSize = 0.1 , seed = None ):
924+ lossType = "logistic" , maxIter = 20 , stepSize = 0.1 , seed = None , subsamplingRate = 1.0 ):
925925 """
926926 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
927927 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
928928 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
929- lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
929+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0 )
930930 Sets params for Gradient Boosted Tree Classification.
931931 """
932932 kwargs = self .setParams ._input_kwargs
0 commit comments