@@ -22,7 +22,6 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -33,7 +32,6 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
@@ -138,20 +136,6 @@ class DecisionTreeClassifier @Since("1.4.0") (
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}

/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
val instances = data.map(_.toInstance)
instr.logPipelineStage(this)
instr.logDataset(instances)
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeClassificationModel]
}

/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
@@ -70,7 +70,8 @@ private[spark] object BaggedPoint {
if (numSubsamples == 1 && subsamplingRate == 1.0) {
convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
} else {
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples,
extractSampleWeight, seed)
}
}
}
@@ -79,6 +80,7 @@ private[spark] object BaggedPoint {
input: RDD[Datum],
subsamplingRate: Double,
numSubsamples: Int,
extractSampleWeight: (Datum => Double),
seed: Long): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
@@ -93,7 +95,7 @@
}
subsampleIndex += 1
}
new BaggedPoint(instance, subsampleCounts)
new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
}
}
}
@@ -313,11 +313,12 @@ private[spark] object GradientBoostedTrees extends Logging {

// Initialize tree
timer.start("building tree 0")
val metadata = RandomForest.buildMetadata(input, treeStrategy,
numTrees = 1, featureSubsetStrategy)
val firstTreeModel = RandomForest.run(input, treeStrategy, numTrees = 1,
featureSubsetStrategy, seed = seed, instr = instr,
parentUID = None, precomputedMetadata = Some(metadata))
val metadata = DecisionTreeMetadata.buildMetadata(
input.retag(classOf[Instance]), treeStrategy, numTrees = 1,
featureSubsetStrategy)
val firstTreeModel = RandomForest.runWithMetadata(input, metadata, treeStrategy,
numTrees = 1, featureSubsetStrategy, seed = seed, instr = instr,
parentUID = None)
.head.asInstanceOf[DecisionTreeRegressionModel]
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
@@ -353,9 +354,9 @@
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")

val model = RandomForest.run(data, treeStrategy, numTrees = 1,
featureSubsetStrategy, seed = seed + m, instr = None,
parentUID = None, precomputedMetadata = Some(metadata))
val model = RandomForest.runWithMetadata(data, metadata, treeStrategy,
numTrees = 1, featureSubsetStrategy, seed = seed + m,
instr = None, parentUID = None)
.head.asInstanceOf[DecisionTreeRegressionModel]
timer.stop(s"building tree $m")
// Update partial model
@@ -100,43 +100,26 @@ private[spark] object RandomForest extends Logging with Serializable {
run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
}

def buildMetadata(
input: RDD[Instance],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String): DecisionTreeMetadata = {
val retaggedInput = input.retag(classOf[Instance])
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
}

/**
* Train a random forest.
* Train a random forest with metadata. This method is mainly for GBT, in which metadata can
* be reused among trees.
*
* @param input Training data: RDD of `Instance`
* @param metadata Learning and dataset metadata for DecisionTree.
* @return an unweighted set of trees
*/
def run(
def runWithMetadata(
input: RDD[Instance],
metadata: DecisionTreeMetadata,
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None,
precomputedMetadata: Option[DecisionTreeMetadata] = None): Array[DecisionTreeModel] = {

parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()

timer.start("total")

timer.start("init")

val retaggedInput = input.retag(classOf[Instance])
val metadata = precomputedMetadata.getOrElse {
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
}

instr match {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
@@ -150,6 +133,12 @@ private[spark] object RandomForest extends Logging with Serializable {
logInfo("weightedNumExamples: " + metadata.weightedNumExamples)
}

timer.start("total")

timer.start("init")

val retaggedInput = input.retag(classOf[Instance])

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplits")
@@ -225,7 +214,7 @@
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
// Sanity check (should never occur):
assert(nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
@@ -285,6 +274,32 @@
}
}

/**
* Train a random forest.
*
* @param input Training data: RDD of `Instance`
* @return an unweighted set of trees
*/
def run(
input: RDD[Instance],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()

timer.start("build metadata")
val metadata = DecisionTreeMetadata
.buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
timer.stop("build metadata")

runWithMetadata(input, metadata, strategy, numTrees, featureSubsetStrategy,
seed, instr, prune, parentUID)
}
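
The point of splitting the old entry point into runWithMetadata plus this thin run wrapper is that a GBT caller can pay for DecisionTreeMetadata.buildMetadata once and reuse the same metadata on every boosting iteration, as the GradientBoostedTrees change above does. Below is a minimal illustrative sketch of that usage pattern; it is not part of this patch, it assumes it is compiled inside org.apache.spark.ml.tree.impl (these APIs are private[spark]), and the helper name and parameters are hypothetical.

package org.apache.spark.ml.tree.impl

import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.mllib.tree.configuration.{Strategy => OldStrategy}
import org.apache.spark.rdd.RDD

object MetadataReuseSketch {
  /** Build tree metadata once, then train one tree per boosting iteration with it. */
  def boostSketch(
      input: RDD[Instance],
      strategy: OldStrategy,
      numIterations: Int,
      seed: Long): Array[DecisionTreeModel] = {
    // Comparatively expensive: scans the data to derive bins, categorical arity, etc.
    val metadata = DecisionTreeMetadata.buildMetadata(
      input.retag(classOf[Instance]), strategy, numTrees = 1, featureSubsetStrategy = "all")
    // Reused for every iteration instead of being rebuilt per tree.
    (0 until numIterations).map { m =>
      RandomForest.runWithMetadata(input, metadata, strategy, numTrees = 1,
        featureSubsetStrategy = "all", seed = seed + m, instr = None, parentUID = None)
        .head
    }.toArray
  }
}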

/**
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
*
@@ -577,7 +592,7 @@

// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
Contributor Author:
The IDEA editor always shows warnings on these two lines; changed them to avoid the warnings.
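
A small standalone sketch (not from the patch) showing that the two forms are equivalent: both yield the same (index, element) pairs, the second just skips the intermediate collection view that triggers the IDE warning.

object SwapDemo {
  def main(args: Array[String]): Unit = {
    val aggs = Array("agg0", "agg1", "agg2")  // stands in for Array[DTStatsAggregator]
    // Old form: build a view, zip, swap, then convert to an iterator.
    val viaView = aggs.view.zipWithIndex.map(_.swap).iterator.toSeq
    // New form: start from the iterator directly.
    val viaIterator = aggs.iterator.zipWithIndex.map(_.swap).toSeq
    assert(viaView == viaIterator)  // Seq((0,agg0), (1,agg1), (2,agg2))
    viaIterator.foreach(println)
  }
}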

}
} else {
input.mapPartitions { points =>
@@ -595,7 +610,7 @@

// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
}
}

@@ -610,6 +625,7 @@
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats))
}.collectAsMap()
nodeToFeaturesBc.destroy()

timer.stop("chooseSplits")

@@ -493,30 +493,31 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
test("training with sample weights") {
val df = binaryDataset
val numClasses = 2
val predEquals = (x: Double, y: Double) => x == y
// (maxIter, maxDepth)
// (maxIter, maxDepth, subsamplingRate, fractionInTol)
val testParams = Seq(
(5, 5),
(5, 10)
(5, 5, 1.0, 0.99),
(5, 10, 1.0, 0.99),
(5, 10, 0.95, 0.9)
)

for ((maxIter, maxDepth) <- testParams) {
for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
val estimator = new GBTClassifier()
.setMaxIter(maxIter)
.setMaxDepth(maxDepth)
.setSubsamplingRate(subsamplingRate)
.setSeed(seed)
.setMinWeightFractionPerNode(0.049)

MLTestingUtils.testArbitrarilyScaledWeights[GBTClassificationModel,
GBTClassifier](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
MLTestingUtils.modelPredictionEquals(df, _ == _, tol))
MLTestingUtils.testOutliersWithSmallWeights[GBTClassificationModel,
GBTClassifier](df.as[LabeledPoint], estimator,
numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
numClasses, MLTestingUtils.modelPredictionEquals(df, _ == _, tol),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[GBTClassificationModel,
GBTClassifier](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
MLTestingUtils.modelPredictionEquals(df, _ == _, tol), seed)
}
}

@@ -321,29 +321,31 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
test("training with sample weights") {
val df = linearRegressionData
val numClasses = 0
// (maxIter, maxDepth)
// (maxIter, maxDepth, subsamplingRate, fractionInTol)
val testParams = Seq(
(5, 5),
(5, 10)
(5, 5, 1.0, 0.98),
(5, 10, 1.0, 0.98),
(5, 10, 0.95, 0.6)
)

for ((maxIter, maxDepth) <- testParams) {
for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
val estimator = new GBTRegressor()
.setMaxIter(maxIter)
.setMaxDepth(maxDepth)
.setSubsamplingRate(subsamplingRate)
.setSeed(seed)
.setMinWeightFractionPerNode(0.1)

MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
GBTRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol))
MLTestingUtils.testOutliersWithSmallWeights[GBTRegressionModel,
GBTRegressor](df.as[LabeledPoint], estimator, numClasses,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95),
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[GBTRegressionModel,
GBTRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 0.95), seed)
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol), seed)
}
}

@@ -91,8 +91,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
// should ignore weight function for now
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
assert(baggedRDD.collect().forall(_.sampleWeight === 2.0))
Contributor:
Just trying to understand: why did the sample weight change in this test?

Contributor Author:
Because this test suite meets the conditions withReplacement=false and numSubsamples != 1, it calls the modified convertToBaggedRDDSamplingWithoutReplacement, and the extractSampleWeight here is (_: LabeledPoint) => 2.0, so the output BaggedPoints will have sampleWeight == 2.0.
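
For context, a stripped-down, self-contained sketch (hypothetical names, plain Scala rather than Spark RDDs) of the behaviour described above: when sampling without replacement with more than one subsample, every bagged point now carries the weight returned by the supplied extractor instead of a hard-coded 1.0, which is why the expected value in the assertion becomes 2.0.

import scala.util.Random

final case class MiniBaggedPoint[D](datum: D, subsampleCounts: Array[Int], sampleWeight: Double)

object BaggedWeightSketch {
  // Analogue of convertToBaggedRDDSamplingWithoutReplacement: draw per-subsample counts
  // and attach the weight produced by extractSampleWeight(datum).
  def bag[D](data: Seq[D],
             subsamplingRate: Double,
             numSubsamples: Int,
             extractSampleWeight: D => Double,
             seed: Long): Seq[MiniBaggedPoint[D]] = {
    val rng = new Random(seed)
    data.map { datum =>
      val counts = Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1 else 0)
      MiniBaggedPoint(datum, counts, extractSampleWeight(datum))
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the suite's constant weight function: every point gets sampleWeight == 2.0,
    // which is what the updated assertion checks.
    val bagged = bag(Seq("a", "b", "c"), subsamplingRate = 0.5, numSubsamples = 3,
      extractSampleWeight = (_: String) => 2.0, seed = 42L)
    assert(bagged.forall(_.sampleWeight == 2.0))
  }
}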

}
}
