Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
modified LogisticRegressionSummary and LogisticRegressionModel in cla…
…ssification.py
  • Loading branch information
Ming Jiang authored and jmwdpk committed Sep 11, 2017
commit 60579d5f36ef26f6e3ec675a795ccc86d6a71c8d
115 changes: 113 additions & 2 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,13 @@ def summary(self):
"""
if self.hasSummary:
java_blrt_summary = self._call_java("summary")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename this to java_lrt_summary, as it's not always binary logistic regression.

# Note: Once multiclass is added, update this to return correct summary
return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
java_blrt_interceptVector = self._call_java("interceptVector")
java_blrt_numClasses = self._call_java("numClasses")
java_blrt_binarysummary = self._call_java("binarySummary")
if (len(java_blrt_interceptVector) == 1):
return BinaryLogisticRegressionTrainingSummary(java_blrt_binarysummary)
else:
return LogisticRegressionTrainingSummary(java_blrt_summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
Expand Down Expand Up @@ -611,6 +616,112 @@ def featuresCol(self):
"""
return self._call_java("featuresCol")

@property
@since("2.3.0")
def labels(self):
"""
Returns the sequence of labels in ascending order. This order matches the order used
in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.

Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
training set is missing a label, then all of the arrays over labels
(e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
expected numClasses.
"""
return self._call_java("labels")

@property
@since("2.3.0")
def truePositiveRateByLabel(self):
"""
Returns true positive rate for each label (category).
"""
return self._call_java("truePositiveRateByLabel")

@property
@since("2.3.0")
def falsePositiveRateByLabel(self):
"""
Returns false positive rate for each label (category).
"""
return self._call_java("falsePositiveRateByLabel")

@property
@since("2.3.0")
def precisionByLabel(self):
"""
Returns precision for each label (category).
"""
return self._call_java("precisionByLabel")

@property
@since("2.3.0")
def recallByLabel(self):
"""
Returns recall for each label (category).
"""
return self._call_java("recallByLabel")

@property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this annotation.

@since("2.3.0")
def fMeasureByLabel(self, beta=1.0):
"""
Returns f-measure for each label (category).
"""
return self._call_java("fMeasureByLabel", beta)

@property
@since("2.3.0")
def accuracy(self):
"""
Returns accuracy.
(equals to the total number of correctly classified instances
out of the total number of instances.)
"""
return self._call_java("accuracy")

@property
@since("2.3.0")
def weightedTruePositiveRate(self):
"""
Returns weighted true positive rate.
(equals to precision, recall and f-measure)
"""
return self._call_java("weightedTruePositiveRate")

@property
@since("2.3.0")
def weightedFalsePositiveRate(self):
"""
Returns weighted false positive rate.
"""
return self._call_java("weightedFalsePositiveRate")

@property
@since("2.3.0")
def weightedRecall(self):
"""
Returns weighted averaged recall.
(equals to precision, recall and f-measure)
"""
return self._call_java("weightedRecall")

@property
@since("2.3.0")
def weightedPrecision(self):
"""
Returns weighted averaged precision.
"""
return self._call_java("weightedPrecision")

@property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this annotation.

@since("2.3.0")
def weightedFMeasure(self, beta=1.0):
"""
Returns weighted averaged f-measure.
"""
return self._call_java("weightedFMeasure", beta)


@inherit_doc
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
Expand Down