diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 4a2bc19426ef3..2f2136268b724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -72,6 +72,10 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group setParam */ + @Since("3.1.2") + def setPruneTree(value: Boolean): this.type = set(pruneTree, value) + /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -126,9 +130,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( val instances = extractInstances(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) require(!strategy.bootstrap, "DecisionTreeClassifier does not need bootstrap sampling") + strategy.pruneTree = $(pruneTree) + instr.logNumClasses(numClasses) instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, - probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, + probabilityCol, leafCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed, thresholds) val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f9ce62b91924b..e96d3985d2001 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -75,6 +75,10 @@ class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + /** @group setParam */ + @Since("3.1.2") + def setPruneTree(value: Boolean): this.type = set(pruneTree, value) + /** @group expertSetParam */ @Since("1.4.0") def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -152,10 +156,11 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) strategy.bootstrap = $(bootstrap) + strategy.pruneTree = $(pruneTree) instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol, rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, - maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed, + maxMemoryInMB, minInfoGain, pruneTree, minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval, bootstrap) val trees = RandomForest diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b6bc7aaeed628..51215f693029a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -38,7 +38,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.collection.OpenHashMap import 
org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} - /** * ALGORITHM * @@ -94,8 +93,9 @@ private[spark] object RandomForest extends Logging with Serializable { numTrees: Int, featureSubsetStrategy: String, seed: Long): Array[DecisionTreeModel] = { - val instances = input.map { case LabeledPoint(label, features) => - Instance(label, 1.0, features.asML) + val instances = input.map { + case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) } run(instances, strategy, numTrees, featureSubsetStrategy, seed, None) } @@ -117,7 +117,6 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], - prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() timer.start("total") @@ -141,7 +140,8 @@ private[spark] object RandomForest extends Logging with Serializable { // depth of the decision tree val maxDepth = strategy.maxDepth - require(maxDepth <= 30, + require( + maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") // Max memory usage for aggregates @@ -163,7 +163,9 @@ private[spark] object RandomForest extends Logging with Serializable { // At first, all the rows belong to the root nodes (node Id == 1). nodeIds = baggedInput.map { _ => Array.fill(numTrees)(1) } nodeIdCheckpointer = new PeriodicRDDCheckpointer[Array[Int]]( - strategy.getCheckpointInterval, sc, StorageLevel.MEMORY_AND_DISK) + strategy.getCheckpointInterval, + sc, + StorageLevel.MEMORY_AND_DISK) nodeIdCheckpointer.update(nodeIds) } @@ -192,9 +194,10 @@ private[spark] object RandomForest extends Logging with Serializable { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) // Sanity check (should never occur): - assert(nodesForGroup.nonEmpty, + assert( + nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Only send trees to worker if they contain nodes being split this iteration. @@ -203,8 +206,16 @@ private[spark] object RandomForest extends Logging with Serializable { // Choose node splits, and enqueue new nodes as needed. 
timer.start("findBestSplits") - val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack, timer, nodeIds, + val bestSplit = RandomForest.findBestSplits( + baggedInput, + metadata, + topNodesForGroup, + nodesForGroup, + treeToNodeToIndexInfo, + bcSplits, + nodeStack, + timer, + nodeIds, outputBestSplits = strategy.useNodeIdCache) if (strategy.useNodeIdCache) { nodeIds = updateNodeIds(baggedInput, nodeIds, bcSplits, bestSplit) @@ -231,23 +242,28 @@ private[spark] object RandomForest extends Logging with Serializable { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel( + uid, + rootNode.toNode(strategy.pruneTree), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(strategy.pruneTree), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel( + rootNode.toNode(strategy.pruneTree), + numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toNode(strategy.pruneTree), numFeatures)) } } } @@ -265,7 +281,6 @@ private[spark] object RandomForest extends Logging with Serializable { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], - prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -282,9 +297,12 @@ private[spark] object RandomForest extends Logging with Serializable { val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplits") logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) + logDebug( + Range(0, metadata.numFeatures) + .map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + } + .mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. 
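The hunks above replace the old `prune` argument of `RandomForest.run`/`runBagged` with `strategy.pruneTree`, so the pruning decision now travels with the `Strategy` object rather than the call site, and `rootNode.toNode(strategy.pruneTree)` decides whether redundant leaves are merged when the learned tree is converted to a model. Below is a minimal, self-contained sketch of what that merge amounts to, using hypothetical `SimpleLeaf`/`SimpleInternal` stand-ins rather than Spark's internal `LearningNode`/`Node` classes.

object PruneSketch extends App {
  // Hypothetical stand-ins for Spark's internal tree nodes; illustration only.
  sealed trait SimpleNode
  final case class SimpleLeaf(prediction: Double) extends SimpleNode
  final case class SimpleInternal(left: SimpleNode, right: SimpleNode) extends SimpleNode

  // Collapse any internal node whose children reduce to leaves with the same prediction.
  def pruned(node: SimpleNode): SimpleNode = node match {
    case SimpleInternal(left, right) =>
      (pruned(left), pruned(right)) match {
        case (SimpleLeaf(l), SimpleLeaf(r)) if l == r => SimpleLeaf(l)
        case (pl, pr) => SimpleInternal(pl, pr)
      }
    case leaf => leaf
  }

  // A subtree whose splits never change the predicted class collapses to a single leaf
  // when pruneTree is true; with pruneTree = false the tree is kept exactly as built.
  val unpruned = SimpleInternal(SimpleLeaf(1.0), SimpleInternal(SimpleLeaf(1.0), SimpleLeaf(1.0)))
  assert(pruned(unpruned) == SimpleLeaf(1.0))
}

On the user-facing side the same switch is exposed through the new setter added earlier in this patch, e.g. `new DecisionTreeClassifier().setPruneTree(false)` keeps the full tree and its per-leaf class statistics.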
@@ -292,14 +310,26 @@ private[spark] object RandomForest extends Logging with Serializable { val bcSplits = input.sparkContext.broadcast(splits) val baggedInput = BaggedPoint - .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, strategy.bootstrap, - (tp: TreePoint) => tp.weight, seed = seed) + .convertToBaggedRDD( + treeInput, + strategy.subsamplingRate, + numTrees, + strategy.bootstrap, + (tp: TreePoint) => tp.weight, + seed = seed) .persist(StorageLevel.MEMORY_AND_DISK) .setName("bagged tree points") - val trees = runBagged(baggedInput = baggedInput, metadata = metadata, bcSplits = bcSplits, - strategy = strategy, numTrees = numTrees, featureSubsetStrategy = featureSubsetStrategy, - seed = seed, instr = instr, prune = prune, parentUID = parentUID) + val trees = runBagged( + baggedInput = baggedInput, + metadata = metadata, + bcSplits = bcSplits, + strategy = strategy, + numTrees = numTrees, + featureSubsetStrategy = featureSubsetStrategy, + seed = seed, + instr = instr, + parentUID = parentUID) baggedInput.unpersist() bcSplits.destroy() @@ -316,26 +346,27 @@ private[spark] object RandomForest extends Logging with Serializable { bcSplits: Broadcast[Array[Array[Split]]], bestSplits: Array[Map[Int, Split]]): RDD[Array[Int]] = { require(nodeIds != null && bestSplits != null) - input.zip(nodeIds).map { case (point, ids) => - var treeId = 0 - while (treeId < bestSplits.length) { - val bestSplitsInTree = bestSplits(treeId) - if (bestSplitsInTree != null) { - val nodeId = ids(treeId) - bestSplitsInTree.get(nodeId).foreach { bestSplit => - val featureId = bestSplit.featureIndex - val bin = point.datum.binnedFeatures(featureId) - val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { - LearningNode.leftChildIndex(nodeId) - } else { - LearningNode.rightChildIndex(nodeId) + input.zip(nodeIds).map { + case (point, ids) => + var treeId = 0 + while (treeId < bestSplits.length) { + val bestSplitsInTree = bestSplits(treeId) + if (bestSplitsInTree != null) { + val nodeId = ids(treeId) + bestSplitsInTree.get(nodeId).foreach { bestSplit => + val featureId = bestSplit.featureIndex + val bin = point.datum.binnedFeatures(featureId) + val newNodeId = if (bestSplit.shouldGoLeft(bin, bcSplits.value(featureId))) { + LearningNode.leftChildIndex(nodeId) + } else { + LearningNode.rightChildIndex(nodeId) + } + ids(treeId) = newNodeId } - ids(treeId) = newNodeId } + treeId += 1 } - treeId += 1 - } - ids + ids } } @@ -387,7 +418,11 @@ private[spark] object RandomForest extends Logging with Serializable { var splitIndex = 0 while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples, + agg.featureUpdate( + leftNodeFeatureOffset, + splitIndex, + treePoint.label, + numSamples, sampleWeight) } splitIndex += 1 @@ -502,8 +537,9 @@ private[spark] object RandomForest extends Logging with Serializable { logDebug(s"numFeatures = ${metadata.numFeatures}") logDebug(s"numClasses = ${metadata.numClasses}") logDebug(s"isMulticlass = ${metadata.isMulticlass}") - logDebug(s"isMulticlassWithCategoricalFeatures = " + - s"${metadata.isMulticlassWithCategoricalFeatures}") + logDebug( + s"isMulticlassWithCategoricalFeatures = " + + s"${metadata.isMulticlassWithCategoricalFeatures}") logDebug(s"using nodeIdCache = $useNodeIdCache") /* @@ -530,11 +566,21 @@ private[spark] object RandomForest extends Logging with Serializable { val numSamples = 
baggedPoint.subsampleCounts(treeIndex) val sampleWeight = baggedPoint.sampleWeight if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight, + orderedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + numSamples, + sampleWeight, featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, - metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode) + mixedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + splits, + metadata.unorderedFeatures, + numSamples, + sampleWeight, + featuresForNode) } agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight) } @@ -555,11 +601,16 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint], splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, baggedPoint, splits) + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint, + splits) } agg } @@ -571,12 +622,17 @@ private[spark] object RandomForest extends Logging with Serializable { agg: Array[DTStatsAggregator], dataPoint: (BaggedPoint[TreePoint], Array[Int]), splits: Array[Array[Split]]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val baggedPoint = dataPoint._1 - val nodeIdCache = dataPoint._2 - val nodeIndex = nodeIdCache(treeIndex) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), - agg, baggedPoint, splits) + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint, + splits) } agg } @@ -585,8 +641,8 @@ private[spark] object RandomForest extends Logging with Serializable { * Get node index in group --> features indices map, * which is a short cut to find feature indices for a node given node index in group. 
*/ - def getNodeToFeatures( - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) + : Option[Map[Int, Array[Int]]] = { if (!metadata.subsamplingFeatures) { None } else { @@ -594,7 +650,8 @@ private[spark] object RandomForest extends Logging with Serializable { treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => nodeIdToNodeInfo.values.foreach { nodeIndexInfo => assert(nodeIndexInfo.featureSubset.isDefined) - mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = + nodeIndexInfo.featureSubset.get } } Some(mutableNodeToFeatures.toMap) @@ -603,10 +660,11 @@ private[spark] object RandomForest extends Logging with Serializable { // array of nodes to train indexed by node index in group val nodes = new Array[LearningNode](numNodes) - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node - } + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } } // Calculate best splits for all nodes in the group @@ -660,17 +718,20 @@ private[spark] object RandomForest extends Logging with Serializable { } } - val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { - case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } + val nodeToBestSplits = partitionAggregates + .reduceByKey((a, b) => a.merge(b)) + .map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } - // find best split for each node - val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats)) - }.collectAsMap() + // find best split for each node + val (split: Split, stats: ImpurityStats) = + binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + } + .collectAsMap() nodeToFeaturesBc.destroy() timer.stop("chooseSplits") @@ -682,55 +743,64 @@ private[spark] object RandomForest extends Logging with Serializable { } // Iterate over all nodes in this group. - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - val nodeIndex = node.id - val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: ImpurityStats) = - nodeToBestSplits(aggNodeIndex) - logDebug(s"best split = $split") - - // Extract info for this node. Create children if not leaf. 
- val isLeaf = - (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.isLeaf = isLeaf - node.stats = stats - logDebug(s"Node = $node") - - if (!isLeaf) { - node.split = Some(split) - val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) - val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) - node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) - node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) - - if (outputBestSplits) { - val bestSplitsInTree = bestSplits(treeIndex) - if (bestSplitsInTree == null) { - bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) - } else { - bestSplitsInTree.update(nodeIndex, split) + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug(s"best split = $split") + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug(s"Node = $node") + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) + val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) + node.leftChild = Some( + LearningNode( + LearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some( + LearningNode( + LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (outputBestSplits) { + val bestSplitsInTree = bestSplits(treeIndex) + if (bestSplitsInTree == null) { + bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split) + } else { + bestSplitsInTree.update(nodeIndex, split) + } } - } - // enqueue left child and right child if they are not leaves - if (!leftChildIsLeaf) { - nodeStack.prepend((treeIndex, node.leftChild.get)) - } - if (!rightChildIsLeaf) { - nodeStack.prepend((treeIndex, node.rightChild.get)) - } + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeStack.prepend((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeStack.prepend((treeIndex, node.rightChild.get)) + } - logDebug(s"leftChildIndex = ${node.leftChild.get.id}" + - s", impurity = ${stats.leftImpurity}") - logDebug(s"rightChildIndex = ${node.rightChild.get.id}" + - s", impurity = ${stats.rightImpurity}") + logDebug( + s"leftChildIndex = ${node.leftChild.get.id}" + + s", impurity = ${stats.leftImpurity}") + logDebug( + s"rightChildIndex = ${node.rightChild.get.id}" + + s", impurity = ${stats.rightImpurity}") + } } - } } if (outputBestSplits) { @@ -800,8 +870,12 @@ private[spark] object RandomForest extends Logging with Serializable { return 
ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) + new ImpurityStats( + gain, + impurity, + parentImpurityCalculator, + leftImpurityCalculator, + rightImpurityCalculator) } /** @@ -825,130 +899,156 @@ private[spark] object RandomForest extends Logging with Serializable { } val validFeatureSplits = - Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - }.withFilter { case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } + Iterator + .range(0, binAggregates.metadata.numFeaturesPerNode) + .map { featureIndexIdx => + featuresForNode + .map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + } + .withFilter { + case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } // For each (feature, split), calculate the gain, and select the best (feature, split). val splitsAndImpurityInfo = - validFeatureSplits.map { case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIdx => - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) 
- * + validFeatureSplits.map { + case (featureIndexIdx, featureIndex) => + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIdx => + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val leftChildStats = + binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates + .getParentImpurityCalculator() + .subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines + * which splits are considered. (With K categories, we + * consider K - 1 possible splits.) + * * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) + */ + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. 
+ categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) + } else { + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. + categoryStats.predict + } } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. - categoryStats.predict + Double.MaxValue } - } else { - Double.MaxValue + (featureValue, centroid) } - (featureValue, centroid) - } - logDebug(s"Centroids for categorical variable: " + - s"${centroidForCategories.mkString(",")}") - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug(s"Sorted centroids for categorical variable = " + - s"${categoriesSortedByCentroid.mkString(",")}") - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 + logDebug( + s"Centroids for categorical variable: " + + s"${centroidForCategories.mkString(",")}") + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug( + s"Sorted centroids for categorical variable = " + + s"${categoriesSortedByCentroid.mkString(",")}") + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. 
- val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) - } } val (bestSplit, bestSplitStats) = @@ -959,11 +1059,13 @@ private[spark] object RandomForest extends Logging with Serializable { val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { - (new ContinuousSplit(dummyFeatureIndex, 0), + ( + new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) - (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), + ( + new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } } else { @@ -1036,27 +1138,34 @@ private[spark] object RandomForest extends Logging with Serializable { // being spun up that will definitely do no work. val numPartitions = math.min(continuousFeatures.length, input.partitions.length) - input.flatMap { point => - continuousFeatures.iterator - .map(idx => (idx, (point.features(idx), point.weight))) - .filter(_._2._1 != 0.0) - }.aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( - seqOp = { case ((map, c), (v, w)) => - map.changeValue(v, w, _ + w) - (map, c + 1L) - }, - combOp = { case ((map1, c1), (map2, c2)) => - map2.foreach { case (v, w) => - map1.changeValue(v, w, _ + w) - } - (map1, c1 + c2) + input + .flatMap { point => + continuousFeatures.iterator + .map(idx => (idx, (point.features(idx), point.weight))) + .filter(_._2._1 != 0.0) } - ).map { case (idx, (map, c)) => - val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) - val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) - logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") - (idx, splits) - }.collectAsMap() + .aggregateByKey((new OpenHashMap[Double, Double], 0L), numPartitions)( + seqOp = { + case ((map, c), (v, w)) => + map.changeValue(v, w, _ + w) + (map, c + 1L) + }, + combOp = { + case ((map1, c1), (map2, c2)) => + map2.foreach { + case (v, w) => + map1.changeValue(v, w, _ + w) + } + (map1, c1 + c2) + }) + .map { + case (idx, (map, c)) => + val thresholds = findSplitsForContinuousFeature(map.toMap, c, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + } + .collectAsMap() } else Map.empty[Int, Array[Split]] val numFeatures = metadata.numFeatures @@ -1127,9 +1236,10 @@ private[spark] object RandomForest extends Logging with Serializable { featureIndex: Int): 
Array[Double] = { val valueWeights = new OpenHashMap[Double, Double] var count = 0L - featureSamples.foreach { case (weight, value) => - valueWeights.changeValue(value, weight, _ + weight) - count += 1L + featureSamples.foreach { + case (weight, value) => + valueWeights.changeValue(value, weight, _ + weight) + count += 1L } findSplitsForContinuousFeature(valueWeights.toMap, count, metadata, featureIndex) } @@ -1152,7 +1262,8 @@ private[spark] object RandomForest extends Logging with Serializable { count: Long, metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { - require(metadata.isContinuous(featureIndex), + require( + metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") val splits = if (partValueWeights.isEmpty) { @@ -1226,7 +1337,8 @@ private[spark] object RandomForest extends Logging with Serializable { private[tree] class NodeIndexInfo( val nodeIndexInGroup: Int, - val featureSubset: Option[Array[Int]]) extends Serializable + val featureSubset: Option[Array[Int]]) + extends Serializable /** * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. @@ -1264,8 +1376,13 @@ private[spark] object RandomForest extends Logging with Serializable { val (treeIndex, node) = nodeStack.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, - metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + Some( + SamplingUtils + .reservoirSampleAndCount( + Range(0, metadata.numFeatures).iterator, + metadata.numFeaturesPerNode, + rng.nextLong()) + ._1) } else { None } @@ -1273,11 +1390,13 @@ private[spark] object RandomForest extends Logging with Serializable { val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { nodeStack.remove(0) - mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += + mutableNodesForGroup.getOrElseUpdate( + treeIndex, + new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo - .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) - = new NodeIndexInfo(numNodesInGroup, featureSubset) + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) = + new NodeIndexInfo(numNodesInGroup, featureSubset) numNodesInGroup += 1 memUsage += nodeMemUsage } else { @@ -1286,9 +1405,10 @@ private[spark] object RandomForest extends Logging with Serializable { } if (memUsage > maxMemoryUsage) { // If maxMemoryUsage is 0, we should still allow splitting 1 node. - logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" + - s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" + - s" $numNodesInGroup nodes in this iteration.") + logWarning( + s"Tree learning is using approximately $memUsage bytes per iteration, which" + + s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" + + s" $numNodesInGroup nodes in this iteration.") } // Convert mutable maps to immutable ones. 
val nodesForGroup: Map[Int, Array[LearningNode]] = @@ -1324,8 +1444,7 @@ private[spark] object RandomForest extends Logging with Serializable { * @param metadata decision tree metadata * @return subsample fraction */ - private def samplesFractionForFindSplits( - metadata: DecisionTreeMetadata): Double = { + private def samplesFractionForFindSplits(metadata: DecisionTreeMetadata): Double = { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) if (requiredSamples < metadata.numExamples) { @@ -1334,4 +1453,5 @@ private[spark] object RandomForest extends Logging with Serializable { 1.0 } } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 768e14f4b74e4..4817d9ddec66e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -211,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params { (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) - setDefault(impurity -> "gini") + /** + * If true, the trained tree is pruned after training: sibling leaf nodes that yield the + * same class prediction are merged. The resulting tree is smaller and faster at prediction + * time, but the per-leaf class probabilities are lost. + * If false, no post-training pruning is performed and the class probabilities are kept. + * (default = true) + * @group param + */ + final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", + "If true, the trained tree is pruned after training: sibling leaf nodes that yield the" + + " same class prediction are merged, giving a smaller tree and faster predictions at the" + + " cost of the per-leaf class probabilities. If false, no post-training pruning is" + + " performed and the class probabilities are kept.") + + setDefault(impurity -> "gini", pruneTree -> true) /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) + /** @group getParam */ + final def getPruneTree: Boolean = $(pruneTree) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 0f6c7033687fa..45c1a9c08a38f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -55,6 +55,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. * If a split has less information gain than minInfoGain, * this split will not be considered as a valid split. + * @param pruneTree If true, the trained tree is pruned after training by merging sibling + * nodes that yield the same prediction. Default value is false. + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. 
If too small, then 1 node will be split per iteration, and * its aggregates may exceed this size. @@ -77,6 +79,7 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("3.1.2") @BeanProperty var pruneTree: Boolean = false, @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @@ -113,12 +116,13 @@ class Strategy @Since("1.3.0") ( categoricalFeaturesInfo: Map[Int, Int], minInstancesPerNode: Int, minInfoGain: Double, + pruneTree: Boolean, maxMemoryInMB: Int, subsamplingRate: Double, useNodeIdCache: Boolean, checkpointInterval: Int) = { this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, 0.0) } // scalastyle:on argcount @@ -200,7 +204,7 @@ class Strategy @Since("1.3.0") ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, - minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, + minInfoGain, pruneTree, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, minWeightFractionPerNode) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 3ca6816ce7c0d..2bb24c19aedb2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -59,8 +59,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, - maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val splits = RandomForest.findSplits(rdd, metadata, seed = 42) @@ -71,13 +76,19 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits(0).length === 0) } - test("Binary classification with 3-ary (ordered) categorical features," + - " with no samples for one category: split calculation") { + test( + "Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, - maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = 
DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -91,11 +102,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0, - Map(), Set(), - Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 200000, + 200000.0, + 0, + 0, + Map(), + Set(), + Array(6), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) @@ -107,16 +130,29 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 8, + 8.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) // possibleSplits <= numSplits { val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1) - .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)) + .filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +161,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3) - .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) + .map(x => (1.0, x.toDouble)) + .filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,11 +172,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0, - Map(), Set(), - Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 12, + 12.0, + 0, + 0, + Map(), + Set(), + Array(5), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) @@ -150,11 +199,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 18, + 18.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble)) val splits = 
RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -164,11 +225,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0, - Map(), Set(), - Array(2), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 17, + 17.0, + 0, + 0, + Map(), + Set(), + Array(2), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -178,14 +251,30 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for arbitrarily scaled data { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, - Map(), Set(), - Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 0, + 0.0, + 0, + 0, + Map(), + Set(), + Array(6), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random)) - val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)} - val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)} + val featureSamplesSmallWeight = featureSamplesUnitWeight.map { + case (w, x) => (w * 0.001, x) + } + val featureSamplesLargeWeight = featureSamplesUnitWeight.map { + case (w, x) => (w * 1000, x) + } val splitsUnitWeight = RandomForest .findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0) val splitsSmallWeight = RandomForest @@ -198,11 +287,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most weight is close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0.0, 0, 0 - ) + val fakeMetadata = new DecisionTreeMetadata( + 1, + 0, + 0.0, + 0, + 0, + Map(), + Set(), + Array(3), + Gini, + QuantileStrategy.Sort, + 0, + 0, + 0.0, + 0.0, + 0, + 0) val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map { case (w, x) => (w.toDouble, x.toDouble) } @@ -216,10 +317,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(lp) val rdd = sc.parallelize(data) - val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, - maxBins = 5) - withClue("DecisionTree requires number of features > 0," + - " but was given an empty features vector") { + val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, maxBins = 5) + withClue( + "DecisionTree requires number of features > 0," + + " but was given an empty features vector") { intercept[IllegalArgumentException] { RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) } @@ -231,23 +332,19 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array.fill(5)(instance) val rdd = sc.parallelize(data) val strategy = new OldStrategy( - OldAlgo.Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 5, - categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 5, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) val Array(tree) = 
RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) assert(tree.rootNode.impurity === -1.0) assert(tree.depth === 0) assert(tree.rootNode.prediction === instance.label) // Test with no categorical features - val strategy2 = new OldStrategy( - OldAlgo.Regression, - Variance, - maxDepth = 2, - maxBins = 5) + val strategy2 = new OldStrategy(OldAlgo.Regression, Variance, maxDepth = 2, maxBins = 5) val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None) assert(tree2.rootNode.impurity === -1.0) assert(tree2.depth === 0) @@ -278,12 +375,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metadata.numBins(1) === 3) // Expecting 2^2 - 1 = 3 splits per feature - def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { + def checkCategoricalSplit( + s: Split, + featureIndex: Int, + leftCategories: Array[Double]): Unit = { assert(s.featureIndex === featureIndex) assert(s.isInstanceOf[CategoricalSplit]) val s0 = s.asInstanceOf[CategoricalSplit] assert(s0.leftCategories === leftCategories) - assert(s0.numCategories === 3) // for this unit test + assert(s0.numCategories === 3) // for this unit test } // Feature 0 checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) @@ -296,12 +396,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with ordered categorical features: split calculations") { - val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val arr = OldDTSuite + .generateCategoricalDataPointsForMulticlassForOrderedFeatures() .map(_.asML.toInstance) assert(arr.length === 3000) val rdd = sc.parallelize(arr) - val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, - maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -331,8 +437,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance)) - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 1, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -345,12 +455,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = Map(0 -> Map( - topNode.id -> new RandomForest.NodeIndexInfo(0, None) - )) + val treeToNodeToIndexInfo = + Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) + RandomForest.findBestSplits( + baggedInput, + metadata, + Map(0 -> topNode), + nodesForGroup, + 
treeToNodeToIndexInfo, + bcSplits, + nodeStack) bcSplits.destroy() // don't enqueue leaf nodes into node queue @@ -375,8 +490,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) val input = sc.parallelize(arr.map(_.toInstance)) - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 5, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val splits = RandomForest.findSplits(input, metadata, seed = 42) val bcSplits = input.sparkContext.broadcast(splits) @@ -389,12 +508,17 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats === null) val nodesForGroup = Map(0 -> Array(topNode)) - val treeToNodeToIndexInfo = Map(0 -> Map( - topNode.id -> new RandomForest.NodeIndexInfo(0, None) - )) + val treeToNodeToIndexInfo = + Map(0 -> Map(topNode.id -> new RandomForest.NodeIndexInfo(0, None))) val nodeStack = new mutable.ListBuffer[(Int, LearningNode)] - RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), - nodesForGroup, treeToNodeToIndexInfo, bcSplits, nodeStack) + RandomForest.findBestSplits( + baggedInput, + metadata, + Map(0 -> topNode), + nodesForGroup, + treeToNodeToIndexInfo, + bcSplits, + nodeStack) bcSplits.destroy() // don't enqueue a node into node queue if its impurity is 0.0 @@ -430,18 +554,32 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val input = sc.parallelize(arr.map(_.toInstance)) // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. 
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) - - val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 1, + numClasses = 2, + categoricalFeaturesInfo = Map(0 -> 3), + maxBins = 3) + + val model = RandomForest + .run( + input, + strategy, + numTrees = 1, + featureSubsetStrategy = "all", + seed = 42, + instr = None, + prune = false) + .head model.rootNode match { - case n: InternalNode => n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case _ => fail("model.rootNode.split was not a CategoricalSplit") - } + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case _ => fail("model.rootNode.split was not a CategoricalSplit") + } case _ => fail("model.rootNode was not an InternalNode") } } @@ -457,18 +595,21 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy2 = new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) - val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head - val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head - - def getChildren(rootNode: Node): Array[InternalNode] = rootNode match { - case n: InternalNode => - assert(n.leftChild.isInstanceOf[InternalNode]) - assert(n.rightChild.isInstanceOf[InternalNode]) - Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => fail("rootNode was not an InternalNode") - } + val tree1 = RandomForest + .run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) + .head + val tree2 = RandomForest + .run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", seed = 42, instr = None) + .head + + def getChildren(rootNode: Node): Array[InternalNode] = + rootNode match { + case n: InternalNode => + assert(n.leftChild.isInstanceOf[InternalNode]) + assert(n.rightChild.isInstanceOf[InternalNode]) + Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => fail("rootNode was not an InternalNode") + } // Single group second level tree construction. 
val children1 = getChildren(tree1.rootNode) @@ -514,8 +655,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { nodeStack.prepend((treeIndex, topNodes(treeIndex))) } val rng = new scala.util.Random(seed = seed) - val (nodesForGroup: Map[Int, Array[LearningNode]], - treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + val ( + nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) assert(nodesForGroup.size === numTrees, failString) @@ -523,12 +665,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { if (numFeaturesPerNode == numFeatures) { // featureSubset values should all be None - assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + assert( + treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), failString) } else { // Check number of features. - assert(treeToNodeToIndexInfo.values.forall(_.values.forall( - _.featureSubset.get.length === numFeaturesPerNode)), failString) + assert( + treeToNodeToIndexInfo.values.forall( + _.values.forall(_.featureSubset.get.length === numFeaturesPerNode)), + failString) } } } @@ -536,7 +681,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "log2", + checkFeatureSubsetStrategy( + numTrees = 1, + "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -554,7 +701,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -563,7 +710,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "log2", + checkFeatureSubsetStrategy( + numTrees = 2, + "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) @@ -577,7 +726,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) } @@ -586,15 +735,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = 
categoricalFeaturesInfo) + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 2, + numClasses = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } test("Binary classification with continuous features and node Id cache: subsampling features") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 2, + numClasses = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } @@ -647,7 +804,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance - val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + val expected = Vectors.dense( + (1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) assert(importances ~== expected relTol 0.01) } @@ -676,23 +834,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)), Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)), Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)), - Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)) - ) + Instance(1.0, 1.0, Vectors.dense(1.0, 1.0))) val rdd = sc.parallelize(arr) val numClasses = 2 - val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, - numClasses = numClasses, maxBins = 32) - - val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None).head - - val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Classification, + impurity = Gini, + maxDepth = 4, + numClasses = numClasses, + maxBins = 32) + + val prunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None, + prune = true) + .head + + val unprunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None, + prune = false) + .head + + val defaultBehaviorTree = RandomForest + .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) + .head assert(prunedTree.numNodes === 5) assert(unprunedTree.numNodes === 7) + assert(defaultBehaviorTree.numNodes == unprunedTree.numNodes) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) } @@ -707,21 +890,47 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)), Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)), Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)), - Instance(0.5, 1.0, Vectors.dense(1.0, 1.0)) - ) + Instance(0.5, 1.0, Vectors.dense(1.0, 1.0))) val rdd = sc.parallelize(arr) - val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, - numClasses = 0, maxBins = 32) - - val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = 
"auto", - seed = 42, instr = None).head - - val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", - seed = 42, instr = None, prune = false).head + val strategy = new OldStrategy( + algo = OldAlgo.Regression, + impurity = Variance, + maxDepth = 4, + numClasses = 0, + maxBins = 32) + + val prunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None, + prune = true) + .head + + val unprunedTree = RandomForest + .run( + rdd, + strategy, + numTrees = 1, + featureSubsetStrategy = "auto", + seed = 42, + instr = None, + prune = false) + .head + + val defaultBehaviorTree = RandomForest + .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None) + .head assert(prunedTree.numNodes === 3) assert(unprunedTree.numNodes === 5) + + assert(defaultBehaviorTree.numNodes == unprunedTree.numNodes) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) } @@ -738,13 +947,15 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None) val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) => - TreeTests.checkEqual(unitTree, smallWeightTree) + unitWeightTrees.zip(smallWeightTrees).foreach { + case (unitTree, smallWeightTree) => + TreeTests.checkEqual(unitTree, smallWeightTree) } val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None) - unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) => - TreeTests.checkEqual(unitTree, bigWeightTree) + unitWeightTrees.zip(bigWeightTrees).foreach { + case (unitTree, bigWeightTree) => + TreeTests.checkEqual(unitTree, bigWeightTree) } } @@ -754,11 +965,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(0.0)), Instance(0.0, 1.0, Vectors.dense(0.0)), Instance(0.0, 1.0, Vectors.dense(0.0)), - Instance(1.0, 0.1, Vectors.dense(1.0)) - ) + Instance(1.0, 0.1, Vectors.dense(1.0))) val rdd = sc.parallelize(data) - val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, - minWeightFractionPerNode = 0.5) + val strategy = + new OldStrategy(OldAlgo.Classification, Gini, 3, 2, minWeightFractionPerNode = 0.5) val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) assert(tree1.depth === 0) @@ -777,6 +987,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } private object RandomForestSuite { + def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip @@ -787,12 +998,12 @@ private object RandomForestSuite { private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { if (nodes.isEmpty) { acc - } - else { + } else { nodes.head match { case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount) } } } + } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 79b57d7ed67ad..93290578654bf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1337,7 +1337,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): def 
__init__(self, *args): super(_DecisionTreeClassifierParams, self).__init__(*args) - self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", leafCol="", minWeightFractionPerNode=0.0) @@ -1428,13 +1428,13 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi @keyword_only def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) """ @@ -1448,14 +1448,14 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p @since("1.4.0") def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0): """ setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) Sets params for the DecisionTreeClassifier. @@ -1478,6 +1478,12 @@ def setMaxBins(self, value): """ return self._set(maxBins=value) + def setPruneTree(self, value): + """ + Sets the value of :py:attr:`pruneTree`. + """ + return self._set(pruneTree=value) + def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. 
@@ -1580,7 +1586,7 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): def __init__(self, *args): super(_RandomForestClassifierParams, self).__init__(*args) - self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0, @@ -1667,14 +1673,14 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi @keyword_only def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True): """ __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -1689,14 +1695,14 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p @since("1.4.0") def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, pruneTree=True, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) @@ -1720,6 +1726,12 @@ def setMaxBins(self, value): """ return self._set(maxBins=value) + def setPruneTree(self, value): + """ + Sets the value of :py:attr:`pruneTree`. + """ + return self._set(pruneTree=value) + def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. 
diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index 7ddeb097c4676..8404d3980e1b4 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -338,6 +338,14 @@ class _TreeClassifierParams(Params): "Supported options: " + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString) + pruneTree = Param(Params._dummy(), "pruneTree", + "If true, the trained tree is pruned after training: leaf nodes that yield the same" + + " class prediction are merged. The pruned tree is smaller and faster at prediction" + + " time, but the per-leaf class probabilities are lost." + + " If false, the tree is left unpruned after training, which preserves the class" + + " prediction probabilities at the cost of a larger tree.", + typeConverter=TypeConverters.toBoolean) + def __init__(self): super(_TreeClassifierParams, self).__init__() @@ -347,6 +355,13 @@ def getImpurity(self): Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + + @since("3.1.2") + def getPruneTree(self): + """ + Gets the value of pruneTree or its default value. + """ + return self.getOrDefault(self.pruneTree) class _TreeRegressorParams(_HasVarianceImpurity):
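Note (illustration only, not part of the patch): the snippet below sketches how the new pruneTree parameter could be exercised from PySpark once this change is applied. The SparkSession settings, the toy DataFrame, and the column names are assumptions made up for the example; only DecisionTreeClassifier.setPruneTree and the pruneTree=True default come from the changes above.

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

spark = SparkSession.builder.master("local[2]").appName("pruneTreeExample").getOrCreate()

# Tiny two-feature binary classification set, just enough to grow a small tree.
rows = [(0.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 1.0)]
df = spark.createDataFrame(rows, ["label", "f0", "f1"])
data = VectorAssembler(inputCols=["f0", "f1"], outputCol="features").transform(df)

# pruneTree defaults to True (see _setDefault above): leaves that yield the same
# class prediction are merged after training.
pruned_model = DecisionTreeClassifier(maxDepth=4).fit(data)

# Disabling pruning keeps every leaf, preserving per-leaf class probabilities
# at the cost of a larger tree.
unpruned_model = DecisionTreeClassifier(maxDepth=4).setPruneTree(False).fit(data)

print(pruned_model.numNodes, unpruned_model.numNodes)
spark.stop()

On data this small the two trees may come out identical; the point is only the API surface, mirroring the prune = true / prune = false comparison in the RandomForestSuite tests above.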