Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test on multiclass summary
  • Loading branch information
Ming Jiang authored and jmwdpk committed Sep 11, 2017
commit 1a73e6c6d20d6374379ec2fb237e7f596e77bc62
5 changes: 2 additions & 3 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,9 @@ def summary(self):
"""
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename this to java_lrt_summary, as it's not always binary logistic regression.

java_blrt_interceptVector = self._call_java("interceptVector")
java_blrt_numClasses = self._call_java("numClasses")
java_blrt_binarysummary = self._call_java("binarySummary")
if (len(java_blrt_interceptVector) == 1):
if (java_blrt_numClasses == 2):
java_blrt_binarysummary = self._call_java("binarySummary")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is not necessary, we can just wrap java_lrt_summary with BinaryLogisticRegressionTrainingSummary.

return BinaryLogisticRegressionTrainingSummary(java_blrt_binarysummary)
else:
return LogisticRegressionTrainingSummary(java_blrt_summary)
Expand Down
34 changes: 34 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,40 @@ def test_logistic_regression_summary(self):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], [])),
(2.0, 2.0, Vectors.dense(2.0)),
(2.0, 2.0, Vectors.dense(1.9))],
["label", "weight", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
model = lr.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.probabilityCol, "probability")
self.assertEqual(s.labelCol, "label")
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
objHist = s.objectiveHistory
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.labels, list))
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
self.assertTrue(isinstance(s.precisionByLabel, list))
self.assertTrue(isinstance(s.recallByLabel, list))
self.assertTrue(isinstance(s.fMeasureByLabel, list))
self.assertAlmostEqual(s.accuracy, 0.75, 2)
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
self.assertAlmostEqual(s.weightedFMeasure, 0.65, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add these check for the above test_logistic_regression_summary and rename it to test_binary_logistic_regression_summary, since binary logistic regression summary has these variables as well.

# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add test for evaluation like:

sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)


def test_gaussian_mixture_summary(self):
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
(Vectors.sparse(1, [], []),)]
Expand Down