diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..a40c0e03674b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -76,6 +76,10 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value) + /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ade0960f87a0d..28e0b65e7bcad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -92,6 +92,10 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value) + /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c235209289..b9e1745f02d01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -78,6 +78,10 @@ class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + /** @group expertSetParam */ + @Since("2.2.0") + def setCanMergeChildren(value: Boolean): this.type = set(canMergeChildren, value) + /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..278ba8d76ec0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -342,6 +342,53 @@ private[tree] object LearningNode { new LearningNode(nodeIndex, None, None, None, false, null) } + /** merge the pair of leave of same parent recursively if their prediction is same. */ + def mergeChildrenWithSamePrediction(node: LearningNode): Int = { + + /** find the mergeable leave and merge them. */ + def checkAndMerge(node: LearningNode): (Option[Double], Int) = { + if (node == null) { + (None, 0) + + } else if (node.isLeaf || + // sometimes, node is terminal while its isLeaf is not set. 
+ (node.leftChild.isEmpty && node.rightChild.isEmpty)) { + (Some(node.stats.impurityCalculator.predict), 0) + + } else { + val (leftNode, leftMergeCounts) = checkAndMerge(node.leftChild.orNull) + val (rightNode, rightMergeCounts) = checkAndMerge(node.rightChild.orNull) + val mergeCounts = leftMergeCounts + rightMergeCounts + + if (leftNode.isDefined && rightNode.isDefined && leftNode == rightNode) { + removeChildren(node) + node.isLeaf = true + + (Some(node.stats.impurityCalculator.predict), mergeCounts + 1) + + } else { + (None, mergeCounts) + } + } + } + + val (_, mergeCounts) = checkAndMerge(node) + + mergeCounts + } + + /** delete all children of one node. */ + def removeChildren(learningNode: LearningNode): Unit = { + val left = learningNode.leftChild + val right = learningNode.rightChild + + learningNode.leftChild = None + learningNode.rightChild = None + + left.foreach(removeChildren) + right.foreach(removeChildren) + } + // The below indexing methods were copied from spark.mllib.tree.model.Node /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 82e1ed85a0a14..9b53237a3a73d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -217,6 +217,23 @@ private[spark] object RandomForest extends Logging { } } + // prune tree. + if (strategy.canMergeChildren) { + logInfo("Merge children with same prediction:") + + // merge the pair of leaves if possible. 
+ val mergeCounts = topNodes.map(LearningNode.mergeChildrenWithSamePrediction) + + val mergeCountsOfTreesInfo: String = mergeCounts + .zipWithIndex + .map(_.swap) + .filter(_._2 > 0) + .map { case (id, count) => s"tree: $id, num of nodes merged: $count" } + .mkString("Merge info:\n", "\n", "\n") + + logInfo(mergeCountsOfTreesInfo) + } + val numFeatures = metadata.numFeatures parentUID match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 3fc3ac58b7795..c4e8c48a6b31e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -106,8 +106,18 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") + /** + * If true, tree will try to merge its leaf nodes which share the same parent and output + * the same prediction. + * (default = false) + * @group expertParam + */ + final val canMergeChildren: BooleanParam = new BooleanParam(this, "canMergeChildren", "If true," + + " the tree will try to merge its leaf nodes which share the same parent and output.") + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + maxMemoryInMB -> 256, cacheNodeIds -> false, canMergeChildren -> false, + checkpointInterval -> 10) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -176,6 +186,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) + /** @group expertGetParam */ + final def getCanMergeChildren: Boolean = $(canMergeChildren) + /** * @deprecated This method is deprecated and will be removed in 3.0.0. 
* @group setParam @@ -199,6 +212,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.minInfoGain = getMinInfoGain strategy.minInstancesPerNode = getMinInstancesPerNode strategy.useNodeIdCache = getCacheNodeIds + strategy.canMergeChildren = getCanMergeChildren strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures strategy.subsamplingRate = subsamplingRate diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 58e8f5be7b9f0..b584ae3a34c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -61,6 +61,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will * maintain a separate RDD of node Id cache for each row. + * @param canMergeChildren Merge pairs of leaf nodes of the same parent which + * output the same prediction. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. * E.g. 10 means that the cache will get checkpointed every 10 updates. 
If * the checkpoint directory is not set in @@ -80,6 +82,7 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, + @Since("2.2.0") @BeanProperty var canMergeChildren: Boolean = false, @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { /** @@ -172,7 +175,7 @@ class Strategy @Since("1.3.0") ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, canMergeChildren, checkpointInterval) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 918ab27e2730b..2eb12c6fe0604 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode, Node} import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -350,6 +350,31 @@ class DecisionTreeClassifierSuite dt.fit(df) } + test("SPARK-3159: Check for reducible DecisionTree") { + import DecisionTreeClassifierSuite.hasPairsOfSameChildren 
+ + val df: DataFrame = TreeTests.getIrisDataset(sc).toDF() + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMinInfoGain(0) + .setMinInstancesPerNode(1) + .setMaxDepth(3) + .setSeed(0) + + val m1 = dt.setCanMergeChildren(false).fit(df) + assert(true === hasPairsOfSameChildren(m1.rootNode)) + + val m2 = dt.setCanMergeChildren(true).fit(df) + assert(false === hasPairsOfSameChildren(m2.rootNode)) + + val p1 = m1.transform(df).select(m1.getPredictionCol).collect() + val p2 = m2.transform(df).select(m2.getPredictionCol).collect() + assert(p1.zip(p2).forall { case (y1, y2) => + y1.getDouble(0) === y2.getDouble(0) + }) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -422,4 +447,25 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { TreeTests.checkEqual(oldTreeAsNew, newTree) assert(newTree.numFeatures === numFeatures) } + + /** check if there exists pairs of leaf nodes with same prediction of the same parent. 
*/ + def hasPairsOfSameChildren(node: Node): Boolean = { + def check(node: Node): (Boolean, Option[Double]) = node match { + case n: LeafNode => (false, Some(n.prediction)) + case n: InternalNode => + val (leftFound, leftPredict) = check(n.leftChild) + val (rightFound, rightPredict) = check(n.rightChild) + + if (leftFound || rightFound || + (leftPredict.isDefined && leftPredict == rightPredict)) { + (true, None) + } else { + (false, None) + } + } + + val (found, _) = check(node) + + found + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/LearningNodeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/LearningNodeSuite.scala new file mode 100644 index 0000000000000..739b343e805d8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/LearningNodeSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[LearningNode]]. 
+ */ +class LearningNodeSuite extends SparkFunSuite with MLlibTestSparkContext { + import LearningNodeSuite._ + + test("SPARK-3159: Check for reducible DecisionTree") { + val classA = TreeUtils.makeImpurityStats(2, 0) + val classB = TreeUtils.makeImpurityStats(2, 1) + + val root = TreeUtils.fillBinaryTree(classA)(3) + NodeUtils.setAsLeaf(root, Array(5)) + NodeUtils.setPredictValue(root, Array(6, 7, 12, 13, 14), classB) + /** + * use TreeUtils.toDebugTreeString(root). + * input: + 8(0.0) + 4--| + 9(0.0) + 2--| + 5(0.0) + 1--| + 12(1.0) + 6--| + 13(1.0) + 3--| + 14(1.0) + 7--| + 15(0.0) + */ + assert(true === hasPairsOfSameChildren(root)) + + val mergeCounts = LearningNode.mergeChildrenWithSamePrediction(root) + /** + * use TreeUtils.toDebugTreeString(root). + * result: + 2(0.0) + 1--| + 6(1.0) + 3--| + 14(1.0) + 7--| + 15(0.0) + */ + assert(false === hasPairsOfSameChildren(root)) + assert(mergeCounts === 3) + } +} + +object LearningNodeSuite { + + /** check if there exists pairs of leaf nodes with same prediction of the same parent. */ + def hasPairsOfSameChildren(node: LearningNode): Boolean = + if (node.isLeaf) { + false + + } else { + val left = node.leftChild.get + val right = node.rightChild.get + + if (left.isLeaf && right.isLeaf) { + val leftPredict = left.stats.impurityCalculator.predict + val rightPredict = right.stats.impurityCalculator.predict + + leftPredict == rightPredict + + } else { + // shortcut if find. + hasPairsOfSameChildren(left) || hasPairsOfSameChildren(right) + } + } + + /** helper methods for constructing tree. */ + object TreeUtils { + import LearningNode._ + + /** construct a full binary tree with same impurityStats. 
*/ + def fillBinaryTree(impurityStats: ImpurityStats)(maxHeight: Int): LearningNode = { + def create(id: Int, height: Int): LearningNode = { + if (height == 0) { + new LearningNode(id, None, None, None, true, impurityStats) + + } else { + val leftNode = create(leftChildIndex(id), height - 1) + val rightNode = create(rightChildIndex(id), height - 1) + // use id to help locate node when debug. + val split = new ContinuousSplit(id, id) + + new LearningNode( + id, Some(leftNode), Some(rightNode), + Some(split), false, impurityStats) + } + } + + create(1, maxHeight) + } + + /** create an ImpurityStats for classification. */ + def makeImpurityStats(numClass: Int, predictClassId: Int): ImpurityStats = { + val stat = Array.fill(numClass)(0.0) + stat(predictClassId) = 1.0 + + val calculator = new GiniCalculator(stat) + + new ImpurityStats(0.0, 0.0, calculator, null, null) + } + + /** Full description of model */ + def toDebugTreeString(node: LearningNode, indent: Int = 0): String = { + val prefix: String = " " * indent + val id = node.id + + if (node.isLeaf) { + val predict = node.stats.impurityCalculator.predict + + prefix + id + "(" + predict + ")" + "\n" + + } else { + val left = toDebugTreeString(node.leftChild.get, indent + 1) + val right = toDebugTreeString(node.rightChild.get, indent + 1) + + left + + prefix + id + "--|\n" + + right + } + } + } + + /** helper methods used to operate nodes. */ + object NodeUtils { + import LearningNode._ + + /** assign the ImpurityStats to all nodes required in nodeIds. */ + def setPredictValue(root: LearningNode, + nodeIds: Array[Int], + impurityStats: ImpurityStats): Unit = + nodeIds.foreach { id => + val node = getNode(id, root) + + node.stats = impurityStats + } + + /** set internal nodes as leaf. */ + def setAsLeaf(root: LearningNode, nodeIds: Array[Int]): Unit = + nodeIds.foreach { id => + val node = getNode(id, root) + + if (! 
node.isLeaf) { + removeChildren(node) + node.isLeaf = true + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 92a236928e90b..04e5d54ebd474 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -227,4 +227,163 @@ private[ml] object TreeTests extends SparkFunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) sc.parallelize(arr) } + + /** Iris dataset. + * @see http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale + */ + def getIrisDataset(sc: SparkContext): RDD[LabeledPoint] = { + val arr = Array( + LabeledPoint(1, Vectors.dense(-0.555556, 0.25, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.666667, -0.166667, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.777778, 0.0, -0.898305, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.833333, -0.0833334, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.333333, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.388889, 0.583333, -0.762712, -0.75)), + LabeledPoint(1, Vectors.dense(-0.833333, 0.166667, -0.864407, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.166667, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.944444, -0.25, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.666667, -0.0833334, -0.830508, -1.0)), + LabeledPoint(1, Vectors.dense(-0.388889, 0.416667, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.722222, 0.166667, -0.79661, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.722222, -0.166667, -0.864407, -1.0)), + LabeledPoint(1, Vectors.dense(-1.0, -0.166667, -0.966102, -1.0)), + LabeledPoint(1, Vectors.dense(-0.166667, 0.666667, -0.932203, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.222222, 1.0, -0.830508, -0.75)), + LabeledPoint(1, Vectors.dense(-0.388889, 0.583333, 
-0.898305, -0.75)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.25, -0.864407, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.222222, 0.5, -0.762712, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.5, -0.830508, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.388889, 0.166667, -0.762712, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.416667, -0.830508, -0.75)), + LabeledPoint(1, Vectors.dense(-0.833333, 0.333333, -1.0, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.0833333, -0.762712, -0.666667)), + LabeledPoint(1, Vectors.dense(-0.722222, 0.166667, -0.694915, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, -0.166667, -0.79661, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.166667, -0.79661, -0.75)), + LabeledPoint(1, Vectors.dense(-0.5, 0.25, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.5, 0.166667, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.777778, 0.0, -0.79661, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.722222, -0.0833334, -0.79661, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.388889, 0.166667, -0.830508, -0.75)), + LabeledPoint(1, Vectors.dense(-0.5, 0.75, -0.830508, -1.0)), + LabeledPoint(1, Vectors.dense(-0.333333, 0.833333, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.666667, -0.0833334, -0.830508, -1.0)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.0, -0.932203, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.333333, 0.25, -0.898305, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.666667, -0.0833334, -0.830508, -1.0)), + LabeledPoint(1, Vectors.dense(-0.944444, -0.166667, -0.898305, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.166667, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.25, -0.898305, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.888889, -0.75, -0.898305, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.944444, 0.0, -0.898305, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.25, -0.79661, 
-0.583333)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.5, -0.694915, -0.75)), + LabeledPoint(1, Vectors.dense(-0.722222, -0.166667, -0.864407, -0.833333)), + LabeledPoint(1, Vectors.dense(-0.555556, 0.5, -0.79661, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.833333, 0.0, -0.864407, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.444444, 0.416667, -0.830508, -0.916667)), + LabeledPoint(1, Vectors.dense(-0.611111, 0.0833333, -0.864407, -0.916667)), + LabeledPoint(2, Vectors.dense(0.5, 0.0, 0.254237, 0.0833333)), + LabeledPoint(2, Vectors.dense(0.166667, 0.0, 0.186441, 0.166667)), + LabeledPoint(2, Vectors.dense(0.444444, -0.0833334, 0.322034, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.333333, -0.75, 0.0169491, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(0.222222, -0.333333, 0.220339, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.222222, -0.333333, 0.186441, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(0.111111, 0.0833333, 0.254237, 0.25)), + LabeledPoint(2, Vectors.dense(-0.666667, -0.666667, -0.220339, -0.25)), + LabeledPoint(2, Vectors.dense(0.277778, -0.25, 0.220339, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.5, -0.416667, -0.0169491, 0.0833333)), + LabeledPoint(2, Vectors.dense(-0.611111, -1.0, -0.152542, -0.25)), + LabeledPoint(2, Vectors.dense(-0.111111, -0.166667, 0.0847457, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.0555556, -0.833333, 0.0169491, -0.25)), + LabeledPoint(2, Vectors.dense(-1.32455E-7, -0.25, 0.254237, 0.0833333)), + LabeledPoint(2, Vectors.dense(-0.277778, -0.25, -0.118644, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(0.333333, -0.0833334, 0.152542, 0.0833333)), + LabeledPoint(2, Vectors.dense(-0.277778, -0.166667, 0.186441, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.166667, -0.416667, 0.0508474, -0.25)), + LabeledPoint(2, Vectors.dense(0.0555554, -0.833333, 0.186441, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.277778, -0.583333, -0.0169491, -0.166667)), + LabeledPoint(2, Vectors.dense(-0.111111, 0.0, 
0.288136, 0.416667)), + LabeledPoint(2, Vectors.dense(-1.32455E-7, -0.333333, 0.0169491, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(0.111111, -0.583333, 0.322034, 0.166667)), + LabeledPoint(2, Vectors.dense(-1.32455E-7, -0.333333, 0.254237, -0.0833333)), + LabeledPoint(2, Vectors.dense(0.166667, -0.25, 0.118644, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(0.277778, -0.166667, 0.152542, 0.0833333)), + LabeledPoint(2, Vectors.dense(0.388889, -0.333333, 0.288136, 0.0833333)), + LabeledPoint(2, Vectors.dense(0.333333, -0.166667, 0.355932, 0.333333)), + LabeledPoint(2, Vectors.dense(-0.0555556, -0.25, 0.186441, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.222222, -0.5, -0.152542, -0.25)), + LabeledPoint(2, Vectors.dense(-0.333333, -0.666667, -0.0508475, -0.166667)), + LabeledPoint(2, Vectors.dense(-0.333333, -0.666667, -0.0847458, -0.25)), + LabeledPoint(2, Vectors.dense(-0.166667, -0.416667, -0.0169491, -0.0833333)), + LabeledPoint(2, Vectors.dense(-0.0555556, -0.416667, 0.38983, 0.25)), + LabeledPoint(2, Vectors.dense(-0.388889, -0.166667, 0.186441, 0.166667)), + LabeledPoint(2, Vectors.dense(-0.0555556, 0.166667, 0.186441, 0.25)), + LabeledPoint(2, Vectors.dense(0.333333, -0.0833334, 0.254237, 0.166667)), + LabeledPoint(2, Vectors.dense(0.111111, -0.75, 0.152542, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.277778, -0.166667, 0.0508474, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.333333, -0.583333, 0.0169491, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.333333, -0.5, 0.152542, -0.0833333)), + LabeledPoint(2, Vectors.dense(-1.32455E-7, -0.166667, 0.220339, 0.0833333)), + LabeledPoint(2, Vectors.dense(-0.166667, -0.5, 0.0169491, -0.0833333)), + LabeledPoint(2, Vectors.dense(-0.611111, -0.75, -0.220339, -0.25)), + LabeledPoint(2, Vectors.dense(-0.277778, -0.416667, 0.0847457, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.222222, -0.166667, 0.0847457, -0.0833333)), + LabeledPoint(2, Vectors.dense(-0.222222, -0.25, 0.0847457, -4.03573E-8)), + 
LabeledPoint(2, Vectors.dense(0.0555554, -0.25, 0.118644, -4.03573E-8)), + LabeledPoint(2, Vectors.dense(-0.555556, -0.583333, -0.322034, -0.166667)), + LabeledPoint(2, Vectors.dense(-0.222222, -0.333333, 0.0508474, -4.03573E-8)), + LabeledPoint(3, Vectors.dense(0.111111, 0.0833333, 0.694915, 1.0)), + LabeledPoint(3, Vectors.dense(-0.166667, -0.416667, 0.38983, 0.5)), + LabeledPoint(3, Vectors.dense(0.555555, -0.166667, 0.661017, 0.666667)), + LabeledPoint(3, Vectors.dense(0.111111, -0.25, 0.559322, 0.416667)), + LabeledPoint(3, Vectors.dense(0.222222, -0.166667, 0.627119, 0.75)), + LabeledPoint(3, Vectors.dense(0.833333, -0.166667, 0.898305, 0.666667)), + LabeledPoint(3, Vectors.dense(-0.666667, -0.583333, 0.186441, 0.333333)), + LabeledPoint(3, Vectors.dense(0.666667, -0.25, 0.79661, 0.416667)), + LabeledPoint(3, Vectors.dense(0.333333, -0.583333, 0.627119, 0.416667)), + LabeledPoint(3, Vectors.dense(0.611111, 0.333333, 0.728813, 1.0)), + LabeledPoint(3, Vectors.dense(0.222222, 0.0, 0.38983, 0.583333)), + LabeledPoint(3, Vectors.dense(0.166667, -0.416667, 0.457627, 0.5)), + LabeledPoint(3, Vectors.dense(0.388889, -0.166667, 0.525424, 0.666667)), + LabeledPoint(3, Vectors.dense(-0.222222, -0.583333, 0.355932, 0.583333)), + LabeledPoint(3, Vectors.dense(-0.166667, -0.333333, 0.38983, 0.916667)), + LabeledPoint(3, Vectors.dense(0.166667, 0.0, 0.457627, 0.833333)), + LabeledPoint(3, Vectors.dense(0.222222, -0.166667, 0.525424, 0.416667)), + LabeledPoint(3, Vectors.dense(0.888889, 0.5, 0.932203, 0.75)), + LabeledPoint(3, Vectors.dense(0.888889, -0.5, 1.0, 0.833333)), + LabeledPoint(3, Vectors.dense(-0.0555556, -0.833333, 0.355932, 0.166667)), + LabeledPoint(3, Vectors.dense(0.444444, 0.0, 0.59322, 0.833333)), + LabeledPoint(3, Vectors.dense(-0.277778, -0.333333, 0.322034, 0.583333)), + LabeledPoint(3, Vectors.dense(0.888889, -0.333333, 0.932203, 0.583333)), + LabeledPoint(3, Vectors.dense(0.111111, -0.416667, 0.322034, 0.416667)), + LabeledPoint(3, 
Vectors.dense(0.333333, 0.0833333, 0.59322, 0.666667)), + LabeledPoint(3, Vectors.dense(0.611111, 0.0, 0.694915, 0.416667)), + LabeledPoint(3, Vectors.dense(0.0555554, -0.333333, 0.288136, 0.416667)), + LabeledPoint(3, Vectors.dense(-1.32455E-7, -0.166667, 0.322034, 0.416667)), + LabeledPoint(3, Vectors.dense(0.166667, -0.333333, 0.559322, 0.666667)), + LabeledPoint(3, Vectors.dense(0.611111, -0.166667, 0.627119, 0.25)), + LabeledPoint(3, Vectors.dense(0.722222, -0.333333, 0.728813, 0.5)), + LabeledPoint(3, Vectors.dense(1.0, 0.5, 0.830508, 0.583333)), + LabeledPoint(3, Vectors.dense(0.166667, -0.333333, 0.559322, 0.75)), + LabeledPoint(3, Vectors.dense(0.111111, -0.333333, 0.38983, 0.166667)), + LabeledPoint(3, Vectors.dense(-1.32455E-7, -0.5, 0.559322, 0.0833333)), + LabeledPoint(3, Vectors.dense(0.888889, -0.166667, 0.728813, 0.833333)), + LabeledPoint(3, Vectors.dense(0.111111, 0.166667, 0.559322, 0.916667)), + LabeledPoint(3, Vectors.dense(0.166667, -0.0833334, 0.525424, 0.416667)), + LabeledPoint(3, Vectors.dense(-0.0555556, -0.166667, 0.288136, 0.416667)), + LabeledPoint(3, Vectors.dense(0.444444, -0.0833334, 0.491525, 0.666667)), + LabeledPoint(3, Vectors.dense(0.333333, -0.0833334, 0.559322, 0.916667)), + LabeledPoint(3, Vectors.dense(0.444444, -0.0833334, 0.38983, 0.833333)), + LabeledPoint(3, Vectors.dense(-0.166667, -0.416667, 0.38983, 0.5)), + LabeledPoint(3, Vectors.dense(0.388889, 0.0, 0.661017, 0.833333)), + LabeledPoint(3, Vectors.dense(0.333333, 0.0833333, 0.59322, 1.0)), + LabeledPoint(3, Vectors.dense(0.333333, -0.166667, 0.423729, 0.833333)), + LabeledPoint(3, Vectors.dense(0.111111, -0.583333, 0.355932, 0.5)), + LabeledPoint(3, Vectors.dense(0.222222, -0.166667, 0.423729, 0.583333)), + LabeledPoint(3, Vectors.dense(0.0555554, 0.166667, 0.491525, 0.833333)), + LabeledPoint(3, Vectors.dense(-0.111111, -0.166667, 0.38983, 0.416667))) + + sc.parallelize(arr) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 
1793da03a2c3e..e7c2add8acf52 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-3159][MLlib] Check for reducible DecisionTree. + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.tree.configuration.Strategy.$default$13"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this"), + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"),