Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added to decision tree classifier and to python
  • Loading branch information
CBribiescas committed Jun 20, 2021
commit 43ee8529f8128689262a291813b51b207cb39874
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group setParam */
@Since("3.1.2")
def setPruneTree(value: Boolean): this.type = set(pruneTree, value)

/** @group expertSetParam */
@Since("1.4.0")
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
Expand Down Expand Up @@ -126,9 +130,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
val instances = extractInstances(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
require(!strategy.bootstrap, "DecisionTreeClassifier does not need bootstrap sampling")
strategy.pruneTree = $(pruneTree)

instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, pruneTree,
maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds)

val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
Expand Down
47 changes: 24 additions & 23 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" discretizing continuous features. Must be at least 2 and at least number of categories" +
" for any categorical feature.", ParamValidators.gtEq(2))

/**
* If true, the trained tree will undergo a 'pruning' process after training in which nodes
* that have the same class predictions will be merged. The benefit being that at prediction
* time the tree will be 'leaner'
* If false, the post-training tree will undergo no pruning. The benefit being that you
* maintain the class prediction probabilities
* (default = false)
* @group param
*/
final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
"If true, the trained tree will undergo a 'pruning' process after training in which nodes" +
" that have the same class predictions will be merged. The benefit being that at prediction" +
" time the tree will be 'leaner'" +
" If false, the post-training tree will undergo no pruning. The benefit being that you" +
" maintain the class prediction probabilities"
)

/**
* Minimum number of instances each child must have after split.
* If a split causes the left or right child to have fewer than minInstancesPerNode,
Expand Down Expand Up @@ -154,7 +137,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" trees.")

setDefault(leafCol -> "", maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, pruneTree -> false, maxMemoryInMB -> 256,
minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
cacheNodeIds -> false, checkpointInterval -> 10)

/** @group setParam */
Expand All @@ -180,9 +163,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain)

/** @group getParam */
final def getPruneTree: Boolean = $(pruneTree)

/** @group expertGetParam */
final def getMaxMemoryInMB: Int = $(maxMemoryInMB)

Expand All @@ -203,7 +183,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxDepth = getMaxDepth
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.pruneTree = getPruneTree
strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds
Expand Down Expand Up @@ -232,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))

setDefault(impurity -> "gini")
/**
* If true, the trained tree will undergo a 'pruning' process after training in which nodes
* that have the same class predictions will be merged. This drawback means that the class
* probabilities will be lost. The benefit being that at prediction time the tree will be
* smaller and have faster predictions
* If false, the post-training tree will undergo no pruning. The benefit being that you
* maintain the class prediction probabilities
* (default = true)
* @group param
*/
final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
"If true, the trained tree will undergo a 'pruning' process after training in which nodes" +
" that have the same class predictions will be merged. This drawback means that the class" +
" probabilities will be lost. The benefit being that at prediction time the tree will be" +
" smaller and have faster predictions" +
" If false, the post-training tree will undergo no pruning. The benefit being that you" +
" maintain the class prediction probabilities"
)

// HERE
setDefault(impurity -> "gini", pruneTree -> true)

/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT)
/** @group getParam */
final def getPruneTree: Boolean = $(pruneTree)

/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{
Algo => OldAlgo,
QuantileStrategy,
Strategy => OldStrategy
}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.OpenHashMap
Expand Down
32 changes: 22 additions & 10 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):

def __init__(self, *args):
super(_DecisionTreeClassifierParams, self).__init__(*args)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", leafCol="", minWeightFractionPerNode=0.0)

Expand Down Expand Up @@ -1428,13 +1428,13 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
@keyword_only
def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
Expand All @@ -1448,14 +1448,14 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p
@since("1.4.0")
def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", seed=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True\
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeClassifier.
Expand All @@ -1478,6 +1478,12 @@ def setMaxBins(self, value):
"""
return self._set(maxBins=value)

def setPruneTree(self, value):
"""
Sets the value of :py:attr:`pruneTree`.
"""
return self._set(pruneTree=value)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
Expand Down Expand Up @@ -1580,7 +1586,7 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):

def __init__(self, *args):
super(_RandomForestClassifierParams, self).__init__(*args)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0,
Expand Down Expand Up @@ -1667,14 +1673,14 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
@keyword_only
def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True\
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
Expand All @@ -1689,14 +1695,14 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p
@since("1.4.0")
def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
Expand All @@ -1720,6 +1726,12 @@ def setMaxBins(self, value):
"""
return self._set(maxBins=value)

def setPruneTree(self, value):
"""
Sets the value of :py:attr:`pruneTree`.
"""
return self._set(pruneTree=value)

def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ class _TreeClassifierParams(Params):
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

pruneTree = Param(Params._dummy(), "pruneTree", "" +
"If true, the trained tree will undergo a 'pruning' process after training in which nodes" +
" that have the same class predictions will be merged. This drawback means that the class" +
" probabilities will be lost. The benefit being that at prediction time the tree will be" +
" smaller and have faster predictions" +
" If false, the post-training tree will undergo no pruning. The benefit being that you" +
" maintain the class prediction probabilities", typeConverter=TypeConverters.toBoolean)

def __init__(self):
super(_TreeClassifierParams, self).__init__()

Expand All @@ -347,6 +355,12 @@ def getImpurity(self):
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
@since("3.1.2")
def getPruneTree(self):
"""
Gets the value of pruneTree or its default value.
"""
return self.getOrDefault(self.pruneTree)


class _TreeRegressorParams(_HasVarianceImpurity):
Expand Down