
Commit 8520d7c

facaiy authored and yanboliang committed
[SPARK-21306][ML] OneVsRest should support setWeightCol
## What changes were proposed in this pull request?

Add a `setWeightCol` method for OneVsRest. `weightCol` is ignored if the classifier doesn't inherit the HasWeightCol trait.

## How was this patch tested?

+ [x] Add a unit test.

Author: Yan Facai (颜发才) <[email protected]>

Closes #18554 from facaiy/BUG/oneVsRest_missing_weightCol.

(cherry picked from commit a5a3189)
Signed-off-by: Yanbo Liang <[email protected]>
1 parent 9498798 commit 8520d7c
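A hedged usage sketch of the new API (not part of the patch; it assumes a Spark build containing this change, and the "label"/"features"/"weight" column names are illustrative):

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("OneVsRestWeightExample").getOrCreate()
import spark.implicits._

// Tiny illustrative training set with an explicit per-instance weight column.
val training = Seq(
  (0.0, Vectors.dense(1.0, 0.8), 1.0),
  (1.0, Vectors.sparse(2, Array.empty[Int], Array.empty[Double]), 2.0),
  (2.0, Vectors.dense(0.5, 0.5), 0.5)
).toDF("label", "features", "weight")

// LogisticRegression mixes in HasWeightCol, so OneVsRest forwards the weight
// column to each per-class binary model it trains.
val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(5))
  .setWeightCol("weight")

val model = ovr.fit(training)
spark.stop()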


4 files changed: +81 −9 lines


mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 36 additions & 3 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
+private[ml] trait OneVsRestParams extends PredictorParams
+  with ClassifierTypeTrait with HasWeightCol {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -299,6 +301,18 @@ final class OneVsRest @Since("1.4.0") (
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
     }
     val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
 
-    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+    val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
+      getClassifier match {
+        case _: HasWeightCol => true
+        case c =>
+          logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+          false
+      }
+    }
+
+    val multiclassLabeled = if (weightColIsUsed) {
+      dataset.select($(labelCol), $(featuresCol), $(weightCol))
+    } else {
+      dataset.select($(labelCol), $(featuresCol))
+    }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.labelCol -> labelColName)
       paramMap.put(classifier.featuresCol -> getFeaturesCol)
       paramMap.put(classifier.predictionCol -> getPredictionCol)
-      classifier.fit(trainingDataset, paramMap)
+      if (weightColIsUsed) {
+        val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+        paramMap.put(classifier_.weightCol -> getWeightCol)
+        classifier_.fit(trainingDataset, paramMap)
+      } else {
+        classifier.fit(trainingDataset, paramMap)
+      }
     }.toArray[ClassificationModel[_, _]]
 
     if (handlePersistence) {
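The gate above relies on the HasWeightCol trait, which is private[ml], so user code cannot pattern-match on it directly. From application code, a comparable check can be made through the public Params API; a minimal sketch, not part of the patch:

import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression}

// hasParam is public on all Params; it reports whether the estimator exposes a weightCol param.
val lr = new LogisticRegression()
val dt = new DecisionTreeClassifier()
println(lr.hasParam("weightCol"))  // true: OneVsRest forwards the weight column
println(dt.hasParam("weightCol"))  // false: OneVsRest logs a warning and ignores weightCol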

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 10 additions & 0 deletions
@@ -157,6 +157,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
   }
 
+  test("SPARK-21306: OneVsRest should support setWeightCol") {
+    val dataset2 = dataset.withColumn("weight", lit(1))
+    // classifier inherits hasWeightCol
+    val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
+    assert(ova.fit(dataset2) !== null)
+    // classifier doesn't inherit hasWeightCol
+    val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
+    assert(ova2.fit(dataset2) !== null)
+  }
+
   test("OneVsRest.copy and OneVsRestModel.copy") {
     val lr = new LogisticRegression()
       .setMaxIter(1)

python/pyspark/ml/classification.py

Lines changed: 21 additions & 6 deletions
@@ -1331,7 +1331,7 @@ def weights(self):
         return self._call_java("weights")
 
 
-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
     """
     Parameters for OneVsRest and OneVsRestModel.
     """
@@ -1394,20 +1394,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None):
+                 classifier=None, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None)
+                 classifier=None, weightCol=None)
         """
         super(OneVsRest, self).__init__()
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+                  classifier=None, weightCol=None):
         """
-        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+                  classifier=None, weightCol=None):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
@@ -1423,7 +1425,18 @@ def _fit(self, dataset):
 
         numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
 
-        multiclassLabeled = dataset.select(labelCol, featuresCol)
+        weightCol = None
+        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+            if isinstance(classifier, HasWeightCol):
+                weightCol = self.getWeightCol()
+            else:
+                warnings.warn("weightCol is ignored, "
+                              "as it is not supported by {} now.".format(classifier))
+
+        if weightCol:
+            multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+        else:
+            multiclassLabeled = dataset.select(labelCol, featuresCol)
 
         # persist if underlying dataset is not persistent.
         handlePersistence = \
@@ -1439,6 +1452,8 @@ def trainSingleClass(index):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                              (classifier.featuresCol, featuresCol),
                              (classifier.predictionCol, predictionCol)])
+            if weightCol:
+                paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)
 
         # TODO: Parallel training for all classes.

python/pyspark/ml/tests.py

Lines changed: 14 additions & 0 deletions
@@ -1218,6 +1218,20 @@ def test_output_columns(self):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits hasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit hasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))
+
 
 class HashingTFTest(SparkSessionTestCase):