diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 881dcefb79be..59aaa1cd457a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -82,6 +82,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = super.setSeed(value)
+ @Since("2.0.0")
+ override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -119,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
- subsamplingRate = 1.0)
+ subsamplingRate = 1.0, getClassWeights)
}
@Since("1.4.1")
@@ -129,7 +132,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
- /** Accessor for supported impurities: entropy, gini */
+ /** Accessor for supported impurities: entropy, gini, weightedgini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -168,7 +171,7 @@ class DecisionTreeClassificationModel private[ml] (
}
override protected def predictRaw(features: Vector): Vector = {
- Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+ Vectors.dense(rootNode.predictImpl(features).impurityStats.weightedStats.clone())
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
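The new setter pairs with the "weightedgini" impurity; a minimal usage sketch (illustrative only; `df` is an assumed DataFrame with the standard "label"/"features" columns):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Up-weight errors on class 1 tenfold; one non-negative weight per class.
val dt = new DecisionTreeClassifier()
  .setImpurity("weightedgini")
  .setClassWeights(Array(1.0, 10.0))
  .setMaxDepth(5)
val model = dt.fit(df)
```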
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index b3c074f83925..5e61b759c7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -98,13 +98,17 @@ class RandomForestClassifier @Since("1.4.0") (
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
+ @Since("2.0.0")
+ override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
- super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
+ getOldImpurity, getSubsamplingRate, getClassWeights)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
@@ -195,7 +199,8 @@ class RandomForestClassificationModel private[ml] (
// Ignore the tree weights since all are 1.0 for now.
val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach { tree =>
- val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+ val classCounts: Array[Double] =
+ tree.rootNode.predictImpl(features).impurityStats.weightedStats
val total = classCounts.sum
if (total != 0) {
var i = 0
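Each tree now votes with its leaf's class-weighted stats, normalized to sum to one; an illustrative standalone computation with made-up numbers:

```scala
// A leaf holding raw counts (90, 10) under class weights (1.0, 10.0) exposes
// weightedStats (90.0, 100.0), so the tree votes ~0.47 / ~0.53 instead of 0.9 / 0.1.
val classCounts = Array(90.0, 100.0)
val total = classCounts.sum            // 190.0
val vote = classCounts.map(_ / total)  // Array(0.4736..., 0.5263...)
```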
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index c4df9d11127f..b2fe5ded6179 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
- subsamplingRate = 1.0)
+ subsamplingRate = 1.0, classWeights = Array())
}
@Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a6dbf21d55e2..9429a053804d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -98,7 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
- super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression,
+ getOldImpurity, getSubsamplingRate, classWeights = Array())
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e4..3d175006b9ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
-
/**
* DecisionTree statistics aggregator for a node.
* This holds a flat array of statistics for a set of (features, bins)
@@ -38,6 +37,7 @@ private[spark] class DTStatsAggregator(
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
+ case WeightedGini => new WeightedGiniAggregator(metadata.numClasses, metadata.classWeights)
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 442f52bf0231..a8ad966adf1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -53,7 +53,8 @@ private[spark] class DecisionTreeMetadata(
val minInstancesPerNode: Int,
val minInfoGain: Double,
val numTrees: Int,
- val numFeaturesPerNode: Int) extends Serializable {
+ val numFeaturesPerNode: Int,
+ val classWeights: Array[Double]) extends Serializable {
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
@@ -207,7 +208,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+ strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode,
+ strategy.classWeights)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eb..fe83d602764a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -657,8 +657,15 @@ private[spark] object RandomForest extends Logging {
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
- val leftWeight = leftCount / totalCount.toDouble
- val rightWeight = rightCount / totalCount.toDouble
+ // For Gini or Entropy impurity the class weights are implicitly uniform,
+ // so the weighted count equals the plain count
+ val leftWeightedCount = leftImpurityCalculator.weightedCount
+ val rightWeightedCount = rightImpurityCalculator.weightedCount
+
+ val totalWeightedCount = leftWeightedCount + rightWeightedCount
+
+ val leftWeight = leftWeightedCount / totalWeightedCount
+ val rightWeight = rightWeightedCount / totalWeightedCount
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
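A worked instance of the weighted gain, with made-up counts:

```scala
// Left child holds raw counts (90, 10), right holds (50, 0), class weights are
// (1.0, 10.0): the child weights shift from the unweighted 100/150 and 50/150
// to 190/240 and 50/240.
val leftWeightedCount = 90.0 * 1.0 + 10.0 * 10.0   // 190.0
val rightWeightedCount = 50.0 * 1.0 + 0.0 * 10.0   // 50.0
val totalWeightedCount = leftWeightedCount + rightWeightedCount
val leftWeight = leftWeightedCount / totalWeightedCount    // ~0.79
val rightWeight = rightWeightedCount / totalWeightedCount  // ~0.21
// gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
```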
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 56c85c9b53e1..029ccfec2e2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -342,9 +342,17 @@ private[ml] object DecisionTreeModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}
+ // Get class weights to construct the ImpurityCalculator. This value is
+ // ignored unless the impurity is WeightedGini.
+ val classWeights: Array[Double] = {
+ val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+ compact(render(classWeightsJson)).split("\\[|,|\\]")
+ .filter(_.nonEmpty).map(_.toDouble)
+ }
+
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
- buildTreeFromNodes(data.collect(), impurityType)
+ buildTreeFromNodes(data.collect(), impurityType, classWeights)
}
/**
@@ -353,7 +361,8 @@ private[ml] object DecisionTreeModelReadWrite {
* @param impurityType Impurity type for this tree
* @return Root node of reconstructed tree
*/
- def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+ classWeights: Array[Double]): Node = {
// Load all nodes, sorted by ID.
val nodes = data.sortBy(_.id)
// Sanity checks; could remove
@@ -365,7 +374,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent nodes.
val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData =>
- val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+ val impurityStats = ImpurityCalculator.getCalculator(impurityType,
+ n.impurityStats, classWeights)
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
@@ -437,6 +447,15 @@ private[ml] object EnsembleModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}
+ // Get class weights to construct the ImpurityCalculator. This value is
+ // ignored unless the impurity is WeightedGini.
+ val classWeights: Array[Double] = {
+ val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+ compact(render(classWeightsJson)).split("\\[|,|\\]")
+ .filter(_.nonEmpty).map(_.toDouble)
+ }
+
val treesMetadataPath = new Path(path, "treesMetadata").toString
val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
.select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
@@ -454,7 +473,8 @@ private[ml] object EnsembleModelReadWrite {
val rootNodesRDD: RDD[(Int, Node)] =
nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
case (treeID: Int, nodeData: Iterable[NodeData]) =>
- treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray,
+ impurityType, classWeights)
}
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
(metadata, treesMetadata.zip(rootNodes), treesWeights)
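The class-weight recovery above round-trips through the rendered JSON param value; the same parsing in isolation:

```scala
// The rendered param value is a JSON array literal such as "[1.0,10.0]".
val rendered = "[1.0,10.0]"
val classWeights: Array[Double] =
  rendered.split("\\[|,|\\]").filter(_.nonEmpty).map(_.toDouble) // Array(1.0, 10.0)
```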
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index d7559f8950c3..aba5ab1aec45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance, WeightedGini}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -155,7 +155,31 @@ private[ml] trait DecisionTreeParams extends PredictorParams
*/
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
- /** (private[ml]) Create a Strategy instance to use with the old API. */
+ /** (private[ml]) Create a Strategy instance to use with the old API, taking class weights. */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity,
+ subsamplingRate: Double,
+ classWeights: Array[Double]): OldStrategy = {
+ val strategy = OldStrategy.defaultStrategy(oldAlgo)
+ strategy.impurity = oldImpurity
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = subsamplingRate
+ strategy.classWeights = classWeights
+ strategy
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API (uniform class weights). */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int,
@@ -174,6 +198,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
+ strategy.classWeights = Array(1.0, 1.0)
strategy
}
}
@@ -185,7 +210,7 @@ private[ml] trait TreeClassifierParams extends Params {
/**
* Criterion used for information gain calculation (case-insensitive).
- * Supported: "entropy" and "gini".
+ * Supported: "entropy", "gini", and "weightedgini".
* (default = gini)
* @group param
*/
@@ -194,7 +219,15 @@ private[ml] trait TreeClassifierParams extends Params {
s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
- setDefault(impurity -> "gini")
+ /**
+ * An array that stores the weights of class labels. All elements must be non-negative.
+ * (default = Array(1.0, 1.0))
+ * @group param
+ */
+ final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+ " that stores the weights of class labels. All elements must be non-negative.")
+
+ setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0))
/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)
@@ -202,11 +235,18 @@ private[ml] trait TreeClassifierParams extends Params {
/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase
+ /** @group setParam */
+ def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+ /** @group getParam */
+ final def getClassWeights: Array[Double] = $(classWeights)
+
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "entropy" => OldEntropy
case "gini" => OldGini
+ case "weightedgini" => WeightedGini
case _ =>
// Should never happen because of check in setter method.
throw new RuntimeException(
@@ -217,7 +257,8 @@ private[ml] trait TreeClassifierParams extends Params {
private[ml] object TreeClassifierParams {
// These options should be lowercase.
- final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+ final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini")
+ .map(_.toLowerCase)
}
private[ml] trait DecisionTreeClassifierParams
@@ -239,7 +280,16 @@ private[ml] trait TreeRegressorParams extends Params {
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
- setDefault(impurity -> "variance")
+ /**
+ * An array that stores the weights of class labels. This parameter is ignored by
+ * regression trees.
+ * (default = Array())
+ * @group param
+ */
+ final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+ " that stores the weights of class labels. All elements must be non-negative.")
+
+ setDefault(impurity -> "variance", classWeights -> Array())
/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)
@@ -247,6 +297,12 @@ private[ml] trait TreeRegressorParams extends Params {
/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase
+ /** @group setParam */
+ def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+ /** @group getParam */
+ final def getClassWeights: Array[Double] = $(classWeights)
+
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
@@ -312,8 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
- oldImpurity: OldImpurity): OldStrategy = {
- super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+ oldImpurity: OldImpurity,
+ classWeights: Array[Double]): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+ oldImpurity, getSubsamplingRate, classWeights)
+ }
+
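+ /** (private[ml]) Old-signature overload; delegates to the above with uniform class weights. */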
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+ oldImpurity, getSubsamplingRate, Array(1.0, 1.0))
}
}
@@ -455,7 +522,9 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
private[ml] def getOldBoostingStrategy(
categoricalFeatures: Map[Int, Int],
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
+ oldAlgo, OldVariance, Array(1.0, 1.0))
+
// NOTE: The old API does not support "seed" so we ignore it.
new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
}
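A sketch of the param's default behavior (setter/getter as defined above):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

// classWeights defaults to uniform binary weights, so pipelines that never call
// setClassWeights keep their previous behavior.
val rf = new RandomForestClassifier()
rf.getClassWeights                   // Array(1.0, 1.0)
rf.setClassWeights(Array(0.5, 2.0))
rf.getClassWeights                   // Array(0.5, 2.0)
```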
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b34e1b1b56c4..e96350db6bb1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance, WeightedGini}
/**
* Stores all the configuration options for tree construction
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @param impurity Criterion used for information gain calculation.
* Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
+ * [[org.apache.spark.mllib.tree.impurity.WeightedGini]],
* [[org.apache.spark.mllib.tree.impurity.Entropy]].
* Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
* @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
@@ -65,6 +66,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
* the checkpoint directory is not set in
* [[org.apache.spark.SparkContext]], this setting is ignored.
+ * @param classWeights Weights of the classes for classification problems. Ignored for
+ * regression problems.
*/
@Since("1.0.0")
class Strategy @Since("1.3.0") (
@@ -80,7 +83,9 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
- @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
+ @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+ @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1.0, 1.0))
+ extends Serializable {
/**
*/
@@ -96,6 +101,29 @@ class Strategy @Since("1.3.0") (
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
}
+ /**
+ * Backwards-compatible constructor matching the old API (without classWeights).
+ */
+ @Since("2.0.0")
+ def this(
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClasses: Int,
+ maxBins: Int,
+ quantileCalculationStrategy: QuantileStrategy,
+ categoricalFeaturesInfo: Map[Int, Int],
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ maxMemoryInMB: Int,
+ subsamplingRate: Double,
+ useNodeIdCache: Boolean,
+ checkpointInterval: Int) {
+ this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
+ categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ subsamplingRate, useNodeIdCache, checkpointInterval, Array())
+ }
+
/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
@@ -140,9 +168,9 @@ class Strategy @Since("1.3.0") (
require(numClasses >= 2,
s"DecisionTree Strategy for Classification must have numClasses >= 2," +
s" but numClasses = $numClasses.")
- require(Set(Gini, Entropy).contains(impurity),
+ require(Set(Gini, Entropy, WeightedGini).contains(impurity),
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
- s" Valid settings: Gini, Entropy")
+ s" Valid settings: Gini, Entropy, WeightedGini")
case Regression =>
require(impurity == Variance,
s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
@@ -163,6 +191,14 @@ class Strategy @Since("1.3.0") (
require(subsamplingRate > 0 && subsamplingRate <= 1,
s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
s"$subsamplingRate")
+ if (impurity == WeightedGini) {
+ require(numClasses == classWeights.length,
+ s"DecisionTree Strategy requires the number of class weights to equal the number of" +
+ s" classes, but there are $numClasses classes and ${classWeights.length} weights")
+ require(classWeights.forall(_ >= 0),
+ s"DecisionTree Strategy requires all class weights to be non-negative," +
+ s" but at least one of them is negative")
+ }
}
/**
@@ -172,7 +208,7 @@ class Strategy @Since("1.3.0") (
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
- maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+ maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, classWeights)
}
}
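A hedged construction sketch for the old mllib API (unlisted constructor arguments keep their defaults):

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.WeightedGini

// Validation (shown above) requires classWeights.length == numClasses and all
// weights non-negative when the impurity is WeightedGini.
val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = WeightedGini,
  maxDepth = 5,
  numClasses = 2,
  classWeights = Array(1.0, 10.0))
```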
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index ff7700d2d1b7..de24ba844451 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -138,6 +138,11 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
*/
def count: Long = stats.sum.toLong
+ /**
+ * Weighted number of data points, which for entropy assumes uniform class weights.
+ */
+ def weightedCount: Double = stats.sum
+
/**
* Prediction which should be made based on the sufficient statistics.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 58dc79b7398e..ded6488ddc79 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -134,6 +134,11 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
*/
def count: Long = stats.sum.toLong
+ /**
+ * Weighted number of data points, which for Gini assumes uniform class weights.
+ */
+ def weightedCount: Double = stats.sum
+
/**
* Prediction which should be made based on the sufficient statistics.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 65f0163ec605..b91752f0ff2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -99,6 +99,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
*/
private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
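+ /**
+ * Class-weighted view of the sufficient statistics. Aliases `stats` unless a subclass
+ * overrides it (as [[WeightedGiniCalculator]] does).
+ */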
+ val weightedStats: Array[Double] = stats
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
@@ -147,6 +148,11 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
*/
def count: Long
+ /**
+ * Weighted number of data points accounted for in the sufficient statistics.
+ */
+ def weightedCount: Double
+
/**
* Prediction which should be made based on the sufficient statistics.
*/
@@ -185,11 +191,13 @@ private[spark] object ImpurityCalculator {
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
* the given stats.
*/
- def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+ def getCalculator(impurity: String, stats: Array[Double],
+ classWeights: Array[Double]): ImpurityCalculator = {
impurity match {
case "gini" => new GiniCalculator(stats)
case "entropy" => new EntropyCalculator(stats)
case "variance" => new VarianceCalculator(stats)
+ case "weightedgini" => new WeightedGiniCalculator(stats, classWeights)
case _ =>
throw new IllegalArgumentException(
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 2423516123b8..1087139fb4bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -122,6 +122,11 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
*/
def count: Long = stats(0).toLong
+ /**
+ * Weighted number of data points; regression has no class weights, so this is simply the count.
+ */
+ def weightedCount: Double = stats(0)
+
/**
* Prediction which should be made based on the sufficient statistics.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
new file mode 100644
index 000000000000..90232d07a691
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
+
+/**
+ * :: Experimental ::
+ * Class for calculating the Gini impurity with class weights using
+ * altered prior method during classification.
+ */
+@Since("2.0.0")
+@Experimental
+object WeightedGini extends Impurity {
+
+ /**
+ * :: DeveloperApi ::
+ * Information calculation for multiclass classification with class weights
+ * @param weightedCounts Array[Double] with class-weighted counts for each label
+ * @param weightedTotalCount sum of the weighted counts over all labels
+ * @return information value, or 0 if weightedTotalCount = 0
+ */
+ @Since("2.0.0")
+ @DeveloperApi
+ override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = {
+ if (weightedTotalCount == 0) {
+ return 0
+ }
+ val numClasses = weightedCounts.length
+ var impurity = 1.0
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ val freq = weightedCounts(classIndex) / weightedTotalCount
+ impurity -= freq * freq
+ classIndex += 1
+ }
+ impurity
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Variance calculation; not supported for this classification impurity
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ * @return nothing; always throws UnsupportedOperationException
+ */
+ @Since("2.0.0")
+ @DeveloperApi
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ throw new UnsupportedOperationException("WeightedGini.calculate")
+
+ /**
+ * Get this impurity instance.
+ * This is useful for passing impurity parameters to a Strategy in Java.
+ */
+ @Since("2.0.0")
+ def instance: this.type = this
+
+}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array[Double])
+ extends ImpurityAggregator(numClasses) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+ if (label >= statsSize) {
+ throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+ s" but requires label < numClasses (= $statsSize).")
+ }
+ if (label < 0) {
+ throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+ s" but requires label to be non-negative.")
+ }
+ allStats(offset + label.toInt) += instanceWeight
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = {
+ new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, classWeights)
+ }
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double])
+ extends ImpurityCalculator(stats) {
+
+ override val weightedStats: Array[Double] =
+ stats.zip(classWeights).map { case (s, w) => s * w }
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), classWeights.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = WeightedGini.calculate(weightedStats, weightedStats.sum)
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats.sum.toLong
+
+ /**
+ * Class-weighted number of data points (sum of the class-weighted sufficient statistics).
+ */
+ def weightedCount: Double = weightedStats.sum
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ indexOfLargestArrayElement(weightedStats)
+ }
+
+ /**
+ * Probability of the label given by [[predict]].
+ */
+ override def prob(label: Double): Double = {
+ val lbl = label.toInt
+ require(lbl < stats.length,
+ s"WeightedGiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "WeightedGiniImpurity does not support negative labels")
+ val cnt = weightedCount
+ if (cnt == 0) {
+ 0
+ } else {
+ weightedStats(lbl) / cnt
+ }
+ }
+
+ override def toString: String = s"WeightedGiniCalculator(stats = [${stats.mkString(", ")}])"
+
+ /**
+ * Add the stats from another calculator into this one, modifying and returning this calculator.
+ * Update the weightedStats at the same time
+ */
+ override def add(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.length == other.stats.length,
+ s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+ s" Sizes are ${stats.length} and ${other.stats.length}.")
+ val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
+ val len = otherCalculator.stats.length
+ var i = 0
+ while (i < len) {
+ stats(i) += otherCalculator.stats(i)
+ weightedStats(i) += otherCalculator.weightedStats(i)
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Subtract the stats from another calculator from this one, modifying and returning this
+ * calculator. Update the weightedStats at the same time
+ */
+ override def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.length == other.stats.length,
+ s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+ s" Sizes are ${stats.length} and ${other.stats.length}.")
+ val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
+ val len = otherCalculator.stats.length
+ var i = 0
+ while (i < len) {
+ stats(i) -= otherCalculator.stats(i)
+ weightedStats(i) -= otherCalculator.weightedStats(i)
+ i += 1
+ }
+ this
+ }
+}
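A worked numeric check of the altered-prior behavior (the calculator is `private[spark]`; values are illustrative):

```scala
// Raw counts (90, 10) with weights (1.0, 10.0) give weightedStats (90.0, 100.0).
// Unweighted Gini would be 1 - 0.9^2 - 0.1^2 = 0.18; the weighted version is
// 1 - (90/190)^2 - (100/190)^2 ≈ 0.499, and the prediction flips from class 0
// to class 1 since 100.0 > 90.0.
val calc = new WeightedGiniCalculator(Array(90.0, 10.0), Array(1.0, 10.0))
calc.calculate() // ≈ 0.4986
calc.predict     // 1.0
```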
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 089d30abb5ef..096ab2467ab8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -69,6 +69,18 @@ class DecisionTreeClassifierSuite
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
+ test("Binary classification with explicitly setting uniform class weights") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("WeightedGini")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ .setSeed(1)
+ .setClassWeights(Array(1.0, 1.0))
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 2
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+ }
+
test("Binary classification stump with ordered categorical features") {
val dt = new DecisionTreeClassifier()
.setImpurity("gini")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 2e99ee157ae9..5ea110ec0d02 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -234,7 +234,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite {
numClasses: Int): Unit = {
val numFeatures = data.first().features.size
val oldStrategy =
- rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
+ rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
+ rf.getOldImpurity, rf.getSubsamplingRate, rf.getClassWeights)
val oldModel = OldRandomForest.trainClassifier(
data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy,
rf.getSeed.toInt)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84a..169dcdd3f567 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -140,7 +140,9 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
categoricalFeatures: Map[Int, Int]): Unit = {
val numFeatures = data.first().features.size
val oldStrategy =
- rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
+ rf.getOldStrategy(categoricalFeatures, numClasses = 0,
+ OldAlgo.Regression, rf.getOldImpurity, rf.getSubsamplingRate,
+ classWeights = Array())
val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dcc2f305df75..dce4e698b82c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0, 0, Array[Double]()
)
val featureSamples = Array.fill(200000)(math.random)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0, 0, Array[Double]()
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0, 0, Array[Double]()
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0, 0, Array[Double]()
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 270104f85b83..57c275baed21 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -94,7 +94,7 @@ This file is divided into 3 sections:
-
+