Commit 32ec9a6

init
fix bagged

1 parent f8cfefa · commit 32ec9a6

8 files changed: +161 -21 lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
(19 additions, 5 deletions)

@@ -69,6 +69,10 @@ class RandomForestClassifier @Since("1.4.0") (
   @Since("1.4.0")
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+  /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   /** @group setParam */
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -118,6 +122,16 @@ class RandomForestClassifier @Since("1.4.0") (
   def setFeatureSubsetStrategy(value: String): this.type =
     set(featureSubsetStrategy, value)

+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * By default the weightCol is not set, so all instances have weight 1.0.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
@@ -132,14 +146,14 @@ class RandomForestClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }

-    val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
+    val instances = extractInstances(dataset)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

-    instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
-      leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB,
-      minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds,
-      checkpointInterval)
+    instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol,
+      rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins,
+      maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed,
+      subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

     val trees = RandomForest
       .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
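
The train() change above swaps extractLabeledPoints(...).map(_.toInstance) for extractInstances(dataset), so each training row now carries an explicit weight rather than an implicit 1.0. A rough Python sketch of the difference, using illustrative namedtuples rather than Spark's actual classes:

    from collections import namedtuple

    # Illustrative stand-ins for the Scala LabeledPoint and Instance types.
    LabeledPoint = namedtuple("LabeledPoint", ["label", "features"])
    Instance = namedtuple("Instance", ["label", "weight", "features"])

    def to_instance(lp):
        # old path: every point was pinned to weight 1.0
        return Instance(lp.label, 1.0, lp.features)

    def extract_instances(rows, weight_col=None):
        # new path: read the weight column when one is set, else default to 1.0
        return [Instance(r["label"],
                         r[weight_col] if weight_col else 1.0,
                         r["features"]) for r in rows]

    rows = [{"label": 1.0, "w": 2.0, "features": [0.5, 1.0]}]
    print(extract_instances(rows, weight_col="w"))  # weight 2.0 is preserved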

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
(19 additions, 4 deletions)

@@ -64,6 +64,10 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   @Since("1.4.0")
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+  /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   /** @group setParam */
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -113,20 +117,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   def setFeatureSubsetStrategy(value: String): this.type =
     set(featureSubsetStrategy, value)

+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * By default the weightCol is not set, so all instances have weight 1.0.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

-    val instances = extractLabeledPoints(dataset).map(_.toInstance)
+    val instances = extractInstances(dataset)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

     instr.logPipelineStage(this)
     instr.logDataset(instances)
-    instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, numTrees,
-      featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
-      minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
+    instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, leafCol, impurity,
+      numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
+      minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, cacheNodeIds,
+      checkpointInterval)

     val trees = RandomForest
       .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
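
Both estimators also gain setMinWeightFractionPerNode. Assuming the param keeps its documented tree-learner meaning (each child of a split must receive at least this fraction of the total weighted sample count), the check reduces to something like this Python sketch:

    # Hedged sketch of the minWeightFractionPerNode rule; the function name
    # and shape here are illustrative, not Spark's internal code.
    def split_is_valid(child_weights, total_weight, min_weight_fraction_per_node):
        min_child_weight = min_weight_fraction_per_node * total_weight
        return all(w >= min_child_weight for w in child_weights)

    # With total weight 100.0 and the 0.05 setting used in the suites below,
    # a 4.0/96.0 split is rejected while 6.0/94.0 passes:
    print(split_is_valid([4.0, 96.0], 100.0, 0.05))  # False
    print(split_is_valid([6.0, 94.0], 100.0, 0.05))  # True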

mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
(4 additions, 2 deletions)

@@ -65,7 +65,8 @@ private[spark] object BaggedPoint {
       seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
     // TODO: implement weighted bootstrapping
     if (withReplacement) {
-      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples,
+        extractSampleWeight, seed)
     } else {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
         convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
@@ -104,6 +105,7 @@ private[spark] object BaggedPoint {
       input: RDD[Datum],
       subsample: Double,
       numSubsamples: Int,
+      extractSampleWeight: (Datum => Double),
       seed: Long): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
@@ -116,7 +118,7 @@ private[spark] object BaggedPoint {
         subsampleCounts(subsampleIndex) = poisson.sample()
         subsampleIndex += 1
       }
-      new BaggedPoint(instance, subsampleCounts)
+      new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
     }
   }
 }
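
With this change the with-replacement path matches the other two sampling paths: every BaggedPoint is built with Poisson(subsamplingRate) counts per subsample plus the point's own sample weight, where previously the weight was left at its default of 1.0. An illustrative Python sketch of the per-point step (not Spark's code):

    import numpy as np

    def bag_point(datum, sample_weight, subsampling_rate, num_subsamples, rng):
        # one Poisson draw per subsample, as in convertToBaggedRDDSamplingWithReplacement
        subsample_counts = rng.poisson(lam=subsampling_rate, size=num_subsamples)
        # the sample weight now rides along on the bagged point
        return (datum, subsample_counts, sample_weight)

    rng = np.random.default_rng(42)
    print(bag_point([1.0, 0.5], 2.0, 1.0, 3, rng))  # (datum, counts, weight 2.0)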

mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
(35 additions, 0 deletions)

@@ -18,6 +18,7 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
@@ -41,6 +42,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {

   private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
   private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
+  private var binaryDataset: DataFrame = _
+  private val seed = 42

   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -50,6 +53,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
     orderedLabeledPoints5_20 =
       sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20))
         .map(_.asML)
+    binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF()
   }

   /////////////////////////////////////////////////////////////////////////////
@@ -259,6 +263,37 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
     })
   }

+  test("training with sample weights") {
+    val df = binaryDataset
+    val numClasses = 2
+    // (numTrees, maxDepth, subsamplingRate, fractionInTol)
+    val testParams = Seq(
+      (20, 5, 1.0, 0.96),
+      (20, 10, 1.0, 0.96),
+      (20, 10, 0.95, 0.96)
+    )
+
+    for ((numTrees, maxDepth, subsamplingRate, tol) <- testParams) {
+      val estimator = new RandomForestClassifier()
+        .setNumTrees(numTrees)
+        .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
+        .setSeed(seed)
+        .setMinWeightFractionPerNode(0.049)
+
+      MLTestingUtils.testArbitrarilyScaledWeights[RandomForestClassificationModel,
+        RandomForestClassifier](df.as[LabeledPoint], estimator,
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol))
+      MLTestingUtils.testOutliersWithSmallWeights[RandomForestClassificationModel,
+        RandomForestClassifier](df.as[LabeledPoint], estimator,
+        numClasses, MLTestingUtils.modelPredictionEquals(df, _ == _, tol),
+        outlierRatio = 2)
+      MLTestingUtils.testOversamplingVsWeighting[RandomForestClassificationModel,
+        RandomForestClassifier](df.as[LabeledPoint], estimator,
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol), seed)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
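
The three MLTestingUtils helpers encode invariants any correct weighting implementation should satisfy: rescaling all weights by a constant must not change predictions, outliers given near-zero weight must not change predictions, and weighting a row by k should behave like duplicating it k times. A hedged PySpark sketch of that last equivalence (my reading of testOversamplingVsWeighting; data and column names are made up):

    from pyspark.sql import SparkSession, functions as F
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.classification import RandomForestClassifier

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    rows = [(1.0, Vectors.dense(1.0, 0.8)), (0.0, Vectors.dense(-1.2, -0.7)),
            (1.0, Vectors.dense(0.9, 1.1)), (0.0, Vectors.dense(-0.8, -1.0))]
    df = spark.createDataFrame(rows, ["label", "features"])

    # positives weighted 3x ...
    weighted = df.withColumn(
        "weight", F.when(F.col("label") == 1.0, 3.0).otherwise(1.0))
    # ... versus positives physically repeated 3x at unit weight
    positives = df.filter("label = 1.0")
    oversampled = (df.union(positives).union(positives)
                     .withColumn("weight", F.lit(1.0)))

    rf = RandomForestClassifier(numTrees=20, weightCol="weight", seed=42)
    pred_w = rf.fit(weighted).transform(df).select("prediction").collect()
    pred_o = rf.fit(oversampled).transform(df).select("prediction").collect()
    # the suite only requires agreement on a fraction of rows (tol = 0.96 above)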

mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
(39 additions, 0 deletions)

@@ -22,9 +22,11 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}

@@ -37,12 +39,18 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
   import testImplicits._

   private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+  private var linearRegressionData: DataFrame = _
+  private val seed = 42

   override def beforeAll(): Unit = {
     super.beforeAll()
     orderedLabeledPoints50_1000 =
       sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
         .map(_.asML))
+
+    linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
+      intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+      xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
   }

   /////////////////////////////////////////////////////////////////////////////
@@ -158,6 +166,37 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
     })
   }

+  test("training with sample weights") {
+    val df = linearRegressionData
+    val numClasses = 0
+    // (numTrees, maxDepth, subsamplingRate, fractionInTol)
+    val testParams = Seq(
+      (50, 5, 1.0, 0.75),
+      (50, 10, 1.0, 0.75),
+      (50, 10, 0.95, 0.78)
+    )
+
+    for ((numTrees, maxDepth, subsamplingRate, tol) <- testParams) {
+      val estimator = new RandomForestRegressor()
+        .setNumTrees(numTrees)
+        .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
+        .setSeed(seed)
+        .setMinWeightFractionPerNode(0.05)
+
+      MLTestingUtils.testArbitrarilyScaledWeights[RandomForestRegressionModel,
+        RandomForestRegressor](df.as[LabeledPoint], estimator,
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol))
+      MLTestingUtils.testOutliersWithSmallWeights[RandomForestRegressionModel,
+        RandomForestRegressor](df.as[LabeledPoint], estimator,
+        numClasses, MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol),
+        outlierRatio = 2)
+      MLTestingUtils.testOversamplingVsWeighting[RandomForestRegressionModel,
+        RandomForestRegressor](df.as[LabeledPoint], estimator,
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol), seed)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
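
Unlike the classifier suite's exact comparison (_ == _), the regression predictions are compared with the relative-tolerance operator imported from TestingUtils. A plain-Python reading of that check (assumed semantics of `_ ~= _ relTol 0.2`):

    def approx_equal(a, b, rel_tol=0.2):
        # treat a and b as equal when they differ by at most 20% relative error
        return abs(a - b) <= rel_tol * max(abs(a), abs(b))

    print(approx_equal(10.0, 11.5))  # True: 15% apart
    print(approx_equal(10.0, 13.0))  # False: about 23% apart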

mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
(1 addition, 2 deletions)

@@ -54,8 +54,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
         baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
         expectedStddev, epsilon = 0.01)
-      // should ignore weight function for now
-      assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
+      assert(baggedRDD.collect().forall(_.sampleWeight === 2.0))
     }
   }

python/pyspark/ml/classification.py
(20 additions, 4 deletions)

@@ -1387,6 +1387,8 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
     >>> td = si_model.transform(df)
     >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
     ...     leafCol="leafId")
+    >>> rf.getMinWeightFractionPerNode()
+    0.0
     >>> model = rf.fit(td)
     >>> model.getLabelCol()
     'indexed'
@@ -1441,14 +1443,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                  numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
-                 leafCol="", minWeightFractionPerNode=0.0):
+                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
-                 leafCol="", minWeightFractionPerNode=0.0)
+                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
         """
         super(RandomForestClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -1467,14 +1469,14 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
-                  leafCol="", minWeightFractionPerNode=0.0):
+                  leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
                   leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
         Sets params for linear classification.
         """
         kwargs = self._input_kwargs
@@ -1559,6 +1561,20 @@ def setCheckpointInterval(self, value):
         """
         return self._set(checkpointInterval=value)

+    @since("3.0.0")
+    def setWeightCol(self, value):
+        """
+        Sets the value of :py:attr:`weightCol`.
+        """
+        return self._set(weightCol=value)
+
+    @since("3.0.0")
+    def setMinWeightFractionPerNode(self, value):
+        """
+        Sets the value of :py:attr:`minWeightFractionPerNode`.
+        """
+        return self._set(minWeightFractionPerNode=value)
+

 class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
                                       _RandomForestClassifierParams, JavaMLWritable,
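
A minimal usage sketch of the new Python params added above (data and column names are invented for illustration):

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.classification import RandomForestClassifier

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([
        (1.0, 2.0, Vectors.dense(0.0, 5.0)),
        (0.0, 1.0, Vectors.dense(1.0, 2.0)),
        (1.0, 3.0, Vectors.dense(2.0, 4.0)),
    ], ["label", "weight", "features"])

    rf = RandomForestClassifier(numTrees=3, maxDepth=2, seed=42,
                                weightCol="weight", minWeightFractionPerNode=0.05)
    model = rf.fit(df)  # rows now contribute to node statistics by weight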

python/pyspark/ml/regression.py
(24 additions, 4 deletions)

@@ -995,6 +995,8 @@ class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLW
     ...     (1.0, Vectors.dense(1.0)),
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
+    >>> rf.getMinWeightFractionPerNode()
+    0.0
     >>> rf.setSeed(42)
     RandomForestRegressor...
     >>> model = rf.fit(df)
@@ -1044,13 +1046,15 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
-                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
+                 featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
+                 weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
-                 featureSubsetStrategy="auto", leafCol=", minWeightFractionPerNode=0.0")
+                 featureSubsetStrategy="auto", leafCol=", minWeightFractionPerNode=0.0", \
+                 weightCol=None)
         """
         super(RandomForestRegressor, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -1068,13 +1072,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                   impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
-                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
+                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0,
+                  weightCol=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                   impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
-                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
+                  featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0, \
+                  weightCol=None)
         Sets params for linear regression.
         """
         kwargs = self._input_kwargs
@@ -1159,6 +1165,20 @@ def setSeed(self, value):
         """
         return self._set(seed=value)

+    @since("3.0.0")
+    def setWeightCol(self, value):
+        """
+        Sets the value of :py:attr:`weightCol`.
+        """
+        return self._set(weightCol=value)
+
+    @since("3.0.0")
+    def setMinWeightFractionPerNode(self, value):
+        """
+        Sets the value of :py:attr:`minWeightFractionPerNode`.
+        """
+        return self._set(minWeightFractionPerNode=value)
+

 class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
                                   JavaMLWritable, JavaMLReadable):
