@@ -2416,6 +2416,18 @@ class LogisticRegressionSuite
blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect())
assert(
blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect())
assert(blorSummary.labels === sameBlorSummary.labels)
assert(blorSummary.truePositiveRateByLabel === sameBlorSummary.truePositiveRateByLabel)
assert(blorSummary.falsePositiveRateByLabel === sameBlorSummary.falsePositiveRateByLabel)
assert(blorSummary.precisionByLabel === sameBlorSummary.precisionByLabel)
assert(blorSummary.recallByLabel === sameBlorSummary.recallByLabel)
assert(blorSummary.fMeasureByLabel === sameBlorSummary.fMeasureByLabel)
assert(blorSummary.accuracy === sameBlorSummary.accuracy)
assert(blorSummary.weightedTruePositiveRate === sameBlorSummary.weightedTruePositiveRate)
assert(blorSummary.weightedFalsePositiveRate === sameBlorSummary.weightedFalsePositiveRate)
assert(blorSummary.weightedRecall === sameBlorSummary.weightedRecall)
assert(blorSummary.weightedPrecision === sameBlorSummary.weightedPrecision)
assert(blorSummary.weightedFMeasure === sameBlorSummary.weightedFMeasure)

lr.setFamily("multinomial")
val mlorModel = lr.fit(smallMultinomialDataset)
120 changes: 117 additions & 3 deletions python/pyspark/ml/classification.py
@@ -528,9 +528,11 @@ def summary(self):
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
# Note: Once multiclass is added, update this to return correct summary
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
java_lrt_summary = self._call_java("summary")
if self.numClasses <= 2:
return BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
else:
return LogisticRegressionTrainingSummary(java_lrt_summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
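
For context, a brief usage sketch (not part of the patch) of what the dispatch on numClasses means from the Python side. The active SparkSession bound to the name `spark` and the toy three-class DataFrame are assumptions made only for illustration:

# Hedged sketch: expected behaviour of the updated `summary` property once this
# change lands. Assumes an active SparkSession is available as `spark`.
from pyspark.ml.classification import (LogisticRegression,
                                       LogisticRegressionTrainingSummary)
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [(0.0, Vectors.dense(0.0)),
     (1.0, Vectors.dense(1.0)),
     (2.0, Vectors.dense(2.0))],
    ["label", "features"])

model = LogisticRegression(maxIter=5).fit(df)
summary = model.summary
# Three classes here, so the generic LogisticRegressionTrainingSummary is
# returned; a dataset with at most two classes would yield the binary subclass.
assert isinstance(summary, LogisticRegressionTrainingSummary)
print(summary.accuracy, summary.weightedPrecision)
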
@@ -585,6 +587,14 @@ def probabilityCol(self):
"""
return self._call_java("probabilityCol")

@property
@since("2.3.0")
def predictionCol(self):
"""
Field in "predictions" which gives the prediction of each class.
"""
return self._call_java("predictionCol")

@property
@since("2.0.0")
def labelCol(self):
@@ -603,6 +613,110 @@ def featuresCol(self):
"""
return self._call_java("featuresCol")

@property
@since("2.3.0")
def labels(self):
"""
Returns the sequence of labels in ascending order. This order matches the order used
in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
training set is missing a label, then all of the arrays over labels
(e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
expected numClasses.
"""
return self._call_java("labels")

@property
@since("2.3.0")
def truePositiveRateByLabel(self):
"""
Returns true positive rate for each label (category).
"""
return self._call_java("truePositiveRateByLabel")

@property
@since("2.3.0")
def falsePositiveRateByLabel(self):
"""
Returns false positive rate for each label (category).
"""
return self._call_java("falsePositiveRateByLabel")

@property
@since("2.3.0")
def precisionByLabel(self):
"""
Returns precision for each label (category).
"""
return self._call_java("precisionByLabel")

@property
@since("2.3.0")
def recallByLabel(self):
"""
Returns recall for each label (category).
"""
return self._call_java("recallByLabel")

@since("2.3.0")
def fMeasureByLabel(self, beta=1.0):
"""
Returns f-measure for each label (category).
"""
return self._call_java("fMeasureByLabel", beta)

@property
@since("2.3.0")
def accuracy(self):
"""
Returns accuracy
(the fraction of correctly classified instances
out of the total number of instances).
"""
return self._call_java("accuracy")

@property
@since("2.3.0")
def weightedTruePositiveRate(self):
"""
Returns weighted true positive rate.
(equal to precision, recall and f-measure)
"""
return self._call_java("weightedTruePositiveRate")

@property
@since("2.3.0")
def weightedFalsePositiveRate(self):
"""
Returns weighted false positive rate.
"""
return self._call_java("weightedFalsePositiveRate")

@property
@since("2.3.0")
def weightedRecall(self):
"""
Returns weighted averaged recall.
(equal to precision, recall and f-measure)
"""
return self._call_java("weightedRecall")

@property
@since("2.3.0")
def weightedPrecision(self):
"""
Returns weighted averaged precision.
"""
return self._call_java("weightedPrecision")

@since("2.3.0")
def weightedFMeasure(self, beta=1.0):
"""
Returns weighted averaged f-measure.
"""
return self._call_java("weightedFMeasure", beta)


@inherit_doc
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
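
For reference, a minimal pure-Python sketch of the standard one-vs-rest definitions behind the per-label and weighted metrics exposed above. It is only meant to make the formulas concrete; it is not Spark's implementation (the JVM side computes these via MulticlassMetrics), and the helper name weighted_metrics is invented for the example:

# Illustrative only: per-label metrics and label-frequency-weighted averages.
from collections import Counter

def weighted_metrics(pairs, beta=1.0):
    """pairs: iterable of (label, prediction) tuples."""
    pairs = list(pairs)
    labels = sorted({l for l, _ in pairs})
    total = len(pairs)
    counts = Counter(l for l, _ in pairs)
    per_label = {"precisionByLabel": [], "recallByLabel": [], "fMeasureByLabel": []}
    weighted = {"weightedPrecision": 0.0, "weightedRecall": 0.0, "weightedFMeasure": 0.0}
    for lab in labels:
        tp = sum(1 for l, p in pairs if l == lab and p == lab)
        fp = sum(1 for l, p in pairs if l != lab and p == lab)
        fn = sum(1 for l, p in pairs if l == lab and p != lab)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        b2 = beta * beta
        f = (1 + b2) * prec * rec / (b2 * prec + rec) if prec + rec else 0.0
        per_label["precisionByLabel"].append(prec)
        per_label["recallByLabel"].append(rec)
        per_label["fMeasureByLabel"].append(f)
        w = counts[lab] / total              # weight = frequency of this label
        weighted["weightedPrecision"] += w * prec
        weighted["weightedRecall"] += w * rec
        weighted["weightedFMeasure"] += w * f
    accuracy = sum(1 for l, p in pairs if l == p) / total
    return labels, per_label, weighted, accuracy

Feeding it the (label, prediction) pairs collected from a summary's predictions DataFrame should reproduce the corresponding summary values up to floating-point error.
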
61 changes: 60 additions & 1 deletion python/pyspark/ml/tests.py
@@ -1451,7 +1451,7 @@ def test_glr_summary(self):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.deviance, s.deviance)

def test_logistic_regression_summary(self):
def test_binary_logistic_regression_summary(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["label", "weight", "features"])
@@ -1464,20 +1464,79 @@ def test_logistic_regression_summary(self):
self.assertEqual(s.probabilityCol, "probability")
self.assertEqual(s.labelCol, "label")
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
objHist = s.objectiveHistory
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.labels, list))
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
self.assertTrue(isinstance(s.precisionByLabel, list))
self.assertTrue(isinstance(s.recallByLabel, list))
self.assertTrue(isinstance(s.fMeasureByLabel(), list))
self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
self.assertTrue(isinstance(s.roc, DataFrame))
self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
self.assertTrue(isinstance(s.pr, DataFrame))
self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
self.assertAlmostEqual(s.accuracy, 1.0, 2)
Reviewer (Contributor): Care to add these to the Scala unit test for the binary summary as well?

Reviewer (Contributor): Also a nit, but we should probably add tests for all of the new attributes, like falsePositiveRateByLabel, as below.

self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], [])),
(2.0, 2.0, Vectors.dense(2.0)),
(2.0, 2.0, Vectors.dense(1.9))],
["label", "weight", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
model = lr.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.probabilityCol, "probability")
self.assertEqual(s.labelCol, "label")
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
objHist = s.objectiveHistory
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.labels, list))
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
self.assertTrue(isinstance(s.precisionByLabel, list))
self.assertTrue(isinstance(s.recallByLabel, list))
self.assertTrue(isinstance(s.fMeasureByLabel(), list))
self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
self.assertAlmostEqual(s.accuracy, 0.75, 2)
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
Reviewer (Contributor): Maybe add beta=1.0 to the methods that take beta as a parameter.

self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
Reviewer (Contributor): Please add a test for evaluation, like:

sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)

sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
Reviewer (Contributor): Nit: as mentioned in the annotation, one check is enough to verify a summary is returned; let's remove the others to simplify the test. Thanks.

self.assertAlmostEqual(sameSummary.weightedTruePositiveRate, s.weightedTruePositiveRate)
self.assertAlmostEqual(sameSummary.weightedFalsePositiveRate, s.weightedFalsePositiveRate)
self.assertAlmostEqual(sameSummary.weightedRecall, s.weightedRecall)
self.assertAlmostEqual(sameSummary.weightedPrecision, s.weightedPrecision)
self.assertAlmostEqual(sameSummary.weightedFMeasure(), s.weightedFMeasure())
self.assertAlmostEqual(sameSummary.weightedFMeasure(1.0), s.weightedFMeasure(1.0))
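
A hedged arithmetic check of the expected values in the multiclass test above. The diff does not show the fitted model's predictions; the breakdown below assumes the label-1.0 row is the single misclassified instance and is predicted as class 2.0, a confusion pattern consistent with accuracy = 0.75, weightedPrecision of about 0.583, weightedFMeasure of about 0.65 and weightedFalsePositiveRate = 0.25:

# ASSUMPTION (inferred from the asserted metrics, not stated in the PR):
# predictions are [2.0, 0.0, 2.0, 2.0] for labels [1.0, 0.0, 2.0, 2.0],
# i.e. only the label-1.0 row is wrong. Per-label one-vs-rest stats then are:
#   label 0.0: tp=1, fp=0, fn=0, count=1 -> precision=1.0, recall=1.0, F1=1.0, FPR=0/3=0.0
#   label 1.0: tp=0, fp=0, fn=1, count=1 -> precision=0.0, recall=0.0, F1=0.0, FPR=0/3=0.0
#   label 2.0: tp=2, fp=1, fn=0, count=2 -> precision=2/3, recall=1.0, F1=0.8, FPR=1/2=0.5
# Label weights are the frequencies 1/4, 1/4 and 2/4.
accuracy = 3 / 4                                                       # 0.75
weightedPrecision = 1/4 * 1.0 + 1/4 * 0.0 + 2/4 * (2 / 3)              # about 0.583
weightedRecall = 1/4 * 1.0 + 1/4 * 0.0 + 2/4 * 1.0                     # 0.75 (== accuracy)
weightedFMeasure = 1/4 * 1.0 + 1/4 * 0.0 + 2/4 * 0.8                   # about 0.65
weightedFalsePositiveRate = 1/4 * 0.0 + 1/4 * 0.0 + 2/4 * 0.5          # 0.25
print(accuracy, round(weightedPrecision, 3), weightedRecall,
      round(weightedFMeasure, 2), weightedFalsePositiveRate)
# -> 0.75 0.583 0.75 0.65 0.25
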

def test_gaussian_mixture_summary(self):
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
(Vectors.sparse(1, [], []),)]