Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 119 additions & 2 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,11 @@ def summary(self):
"""
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename this to java_lrt_summary, as it's not always binary logistic regression.

# Note: Once multiclass is added, update this to return correct summary
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
if (self.numClasses == 2):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (self.numClasses <= 2)

java_blrt_binarysummary = self._call_java("binarySummary")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is not necessary, we can just wrap java_lrt_summary with BinaryLogisticRegressionTrainingSummary.

return BinaryLogisticRegressionTrainingSummary(java_blrt_binarysummary)
else:
return LogisticRegressionTrainingSummary(java_blrt_summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
Expand Down Expand Up @@ -585,6 +588,14 @@ def probabilityCol(self):
"""
return self._call_java("probabilityCol")

@property
@since("2.3.0")
def predictionCol(self):
"""
Field in "predictions" which gives the prediction of each class.
"""
return self._call_java("predictionCol")

@property
@since("2.0.0")
def labelCol(self):
Expand All @@ -603,6 +614,112 @@ def featuresCol(self):
"""
return self._call_java("featuresCol")

@property
@since("2.3.0")
def labels(self):
"""
Returns the sequence of labels in ascending order. This order matches the order used
in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
training set is missing a label, then all of the arrays over labels
(e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
expected numClasses.
"""
return self._call_java("labels")

@property
@since("2.3.0")
def truePositiveRateByLabel(self):
"""
Returns true positive rate for each label (category).
"""
return self._call_java("truePositiveRateByLabel")

@property
@since("2.3.0")
def falsePositiveRateByLabel(self):
"""
Returns false positive rate for each label (category).
"""
return self._call_java("falsePositiveRateByLabel")

@property
@since("2.3.0")
def precisionByLabel(self):
"""
Returns precision for each label (category).
"""
return self._call_java("precisionByLabel")

@property
@since("2.3.0")
def recallByLabel(self):
"""
Returns recall for each label (category).
"""
return self._call_java("recallByLabel")

@property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this annotation.

@since("2.3.0")
def fMeasureByLabel(self, beta=1.0):
"""
Returns f-measure for each label (category).
"""
return self._call_java("fMeasureByLabel", beta)

@property
@since("2.3.0")
def accuracy(self):
"""
Returns accuracy.
(equals to the total number of correctly classified instances
out of the total number of instances.)
"""
return self._call_java("accuracy")

@property
@since("2.3.0")
def weightedTruePositiveRate(self):
"""
Returns weighted true positive rate.
(equals to precision, recall and f-measure)
"""
return self._call_java("weightedTruePositiveRate")

@property
@since("2.3.0")
def weightedFalsePositiveRate(self):
"""
Returns weighted false positive rate.
"""
return self._call_java("weightedFalsePositiveRate")

@property
@since("2.3.0")
def weightedRecall(self):
"""
Returns weighted averaged recall.
(equals to precision, recall and f-measure)
"""
return self._call_java("weightedRecall")

@property
@since("2.3.0")
def weightedPrecision(self):
"""
Returns weighted averaged precision.
"""
return self._call_java("weightedPrecision")

@property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this annotation.

@since("2.3.0")
def weightedFMeasure(self, beta=1.0):
"""
Returns weighted averaged f-measure.
"""
return self._call_java("weightedFMeasure", beta)


@inherit_doc
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
Expand Down
34 changes: 34 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,40 @@ def test_logistic_regression_summary(self):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], [])),
(2.0, 2.0, Vectors.dense(2.0)),
(2.0, 2.0, Vectors.dense(1.9))],
["label", "weight", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
model = lr.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.probabilityCol, "probability")
self.assertEqual(s.labelCol, "label")
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
objHist = s.objectiveHistory
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.labels, list))
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
self.assertTrue(isinstance(s.precisionByLabel, list))
self.assertTrue(isinstance(s.recallByLabel, list))
self.assertTrue(isinstance(s.fMeasureByLabel, list))
self.assertAlmostEqual(s.accuracy, 0.75, 2)
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
self.assertAlmostEqual(s.weightedFMeasure, 0.65, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add these check for the above test_logistic_regression_summary and rename it to test_binary_logistic_regression_summary, since binary logistic regression summary has these variables as well.

# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add test for evaluation like:

sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)


def test_gaussian_mixture_summary(self):
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
(Vectors.sparse(1, [], []),)]
Expand Down