Skip to content
Closed
Prev Previous commit
Next Next commit
TST: python test for setWeightCol
  • Loading branch information
facaiy committed Jul 13, 2017
commit a57f096eb34e57d6a72221f29d84b5ef0c296b9b
7 changes: 4 additions & 3 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,15 +1255,16 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_cache_weightCol_if_necessary(self):
def test_support_for_weightCol(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to also test with a classifier that doesn't have a weight col?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Use DecisionTreeClassifier to test.

df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
(1.0, Vectors.sparse(2, [], []), 1.0),
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
["label", "features", "weight"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
model = ovr.fit(df)
self.assertIsNotNone(model)
self.assertIsNotNone(ovr.fit(df))
ovr2 = OneVsRest(classifier=lr).setWeightCol("weight")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: We can remove test of ovr2 and ovr4, setting param in different way will run the same code at backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleaned.

self.assertIsNotNone(ovr2.fit(df))


class HashingTFTest(SparkSessionTestCase):
Expand Down