Skip to content
Closed
Previous commit
Next commit
ENH: python, cache weightCol
  • Loading branch information
facaiy committed Jul 7, 2017
commit 25d681f38ff670c8209dd1ae5c874e10ac89b402
5 changes: 4 additions & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,10 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
if isinstance(classifier, HasWeightCol) and classifier.getWeightCol():
multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol())
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,16 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_cache_weightCol_if_necessary(self):
    """Fit OneVsRest around a classifier that has ``weightCol`` set.

    The base LogisticRegression is configured with ``weightCol="weight"``,
    so fitting through OneVsRest needs the "weight" column to remain
    available to the underlying classifier. The test passes when ``fit``
    completes and hands back a model object.
    """
    df = self.spark.createDataFrame(
        [(0.0, Vectors.dense(1.0, 0.8), 1.0),
         (1.0, Vectors.sparse(2, [], []), 1.0),
         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
        ["label", "features", "weight"])
    lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
    ovr = OneVsRest(classifier=lr)
    model = ovr.fit(df)
    # A successful fit yields a model; asserting None here would make the
    # test fail exactly when the feature works.
    self.assertIsNotNone(model)


class HashingTFTest(SparkSessionTestCase):

Expand Down