keras/metrics.py: 390 additions, 0 deletions
@@ -1150,6 +1150,396 @@ def __init__(self, thresholds=None, name=None, dtype=None):
dtype=dtype)


class SensitivitySpecificityBase(Metric):
"""Abstract base class for computing sensitivity and specificity.

For additional information about specificity and sensitivity, see the
following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
"""

def __init__(self, value, num_thresholds=200, name=None, dtype=None):
super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
if num_thresholds <= 0:
raise ValueError('`num_thresholds` must be > 0.')
self.value = value
self.true_positives = self.add_weight(
'true_positives',
shape=(num_thresholds,),
initializer='zeros')
self.true_negatives = self.add_weight(
'true_negatives',
shape=(num_thresholds,),
initializer='zeros')
self.false_positives = self.add_weight(
'false_positives',
shape=(num_thresholds,),
initializer='zeros')
self.false_negatives = self.add_weight(
'false_negatives',
shape=(num_thresholds,),
initializer='zeros')

# Compute `num_thresholds` thresholds in [0, 1]
if num_thresholds == 1:
self.thresholds = [0.5]
else:
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
self.thresholds = [0.0] + thresholds + [1.0]
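# For example (illustrative): num_thresholds=5 yields
# self.thresholds == [0.0, 0.25, 0.5, 0.75, 1.0].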

def update_state(self, y_true, y_pred, sample_weight=None):
return metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
},
y_true,
y_pred,
thresholds=self.thresholds,
sample_weight=sample_weight)

def reset_states(self):
num_thresholds = len(self.thresholds)
K.batch_set_value(
[(v, np.zeros((num_thresholds,))) for v in self.variables])


class SensitivityAtSpecificity(SensitivitySpecificityBase):
"""Computes the sensitivity at a given specificity.

`Sensitivity` measures the proportion of actual positives that are correctly
identified as such (tp / (tp + fn)).
`Specificity` measures the proportion of actual negatives that are correctly
identified as such (tn / (tn + fp)).

This metric creates four local variables, `true_positives`, `true_negatives`,
`false_positives` and `false_negatives` that are used to compute the
sensitivity at the given specificity. The threshold for the given specificity
value is computed and used to evaluate the corresponding sensitivity.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

For additional information about specificity and sensitivity, see the
following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[keras.metrics.SensitivityAtSpecificity(0.5)])
```
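
Standalone usage sketch (illustrative values; assumes a backend where
`update_state` applies immediately, so the result can be read via `K.eval`):

```python
m = keras.metrics.SensitivityAtSpecificity(0.5, num_thresholds=1)
m.update_state([0, 0, 1, 1], [0.1, 0.6, 0.3, 0.9])
print(K.eval(m.result()))  # 0.5; tp, fp, tn and fn are all 1 at threshold 0.5
```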

# Arguments
specificity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given specificity.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""

def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
if specificity < 0 or specificity > 1:
raise ValueError('`specificity` must be in the range [0, 1].')
self.specificity = specificity
self.num_thresholds = num_thresholds
super(SensitivityAtSpecificity, self).__init__(
specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)

def result(self):
# Calculate specificities at all the thresholds.
specificities = K.switch(
K.greater(self.true_negatives + self.false_positives, 0),
(self.true_negatives / (self.true_negatives + self.false_positives)),
K.zeros_like(self.thresholds))

# Find the index of the threshold where the specificity is closest to the
# given specificity.
min_index = K.argmin(
K.abs(specificities - self.value), axis=0)
min_index = K.cast(min_index, 'int32')

# Compute sensitivity at that index.
return K.switch(
K.greater((self.true_positives[min_index] +
self.false_negatives[min_index]), 0),
(self.true_positives[min_index] /
(self.true_positives[min_index] + self.false_negatives[min_index])),
K.zeros_like(self.true_positives[min_index]))

def get_config(self):
config = {
'num_thresholds': self.num_thresholds,
'specificity': self.specificity
}
base_config = super(SensitivityAtSpecificity, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class AUC(Metric):
"""Computes the approximate AUC (Area under the curve) via a Riemann sum.

This metric creates four local variables, `true_positives`, `true_negatives`,
`false_positives` and `false_negatives` that are used to compute the AUC.
To discretize the AUC curve, a linearly spaced set of thresholds is used to
compute pairs of recall and precision values. The area under the ROC-curve is
therefore computed using the height of the recall values by the false positive
rate, while the area under the PR-curve is computed using the height of
the precision values by the recall.

This value is ultimately returned as `auc`, an idempotent operation that
computes the area under the discretized ROC- or PR-curve (computed using the
aforementioned variables). The `num_thresholds` variable
controls the degree of discretization with larger numbers of thresholds more
closely approximating the true AUC. The quality of the approximation may vary
dramatically depending on `num_thresholds`. The `thresholds` parameter can be
used to manually specify thresholds which split the predictions more evenly.

For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
approximation may be poor if this is not the case. Setting `summation_method`
to 'minoring' or 'majoring' can help quantify the error in the approximation
by providing a lower or upper bound estimate of the AUC.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile('sgd', loss='mse', metrics=[keras.metrics.AUC()])
```
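
Standalone usage sketch (illustrative values; assumes eager updates). With
`num_thresholds=3` the thresholds are [0 - epsilon, 0.5, 1 + epsilon], which
for the inputs below gives tp = [2, 1, 0] and fp = [2, 0, 0], so the
interpolated ROC-AUC comes out to 0.75:

```python
m = keras.metrics.AUC(num_thresholds=3)
m.update_state([0, 0, 1, 1], [0.0, 0.5, 0.3, 0.9])
print(K.eval(m.result()))  # 0.75
```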

# Arguments
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use when discretizing the roc curve. Values must be > 1.
curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
[default] or 'PR' for the Precision-Recall-curve.
summation_method: (Optional) Specifies the Riemann summation method used
(https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default],
which applies a mid-point summation scheme for `ROC`; for PR-AUC, it
interpolates (true/false) positives but not the ratio that is precision
(see Davis & Goadrich 2006 for details); 'minoring', which applies left
summation for increasing intervals and right summation for decreasing
intervals; 'majoring', which does the opposite.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
thresholds: (Optional) A list of floating point values to use as the
thresholds for discretizing the curve. If set, the `num_thresholds`
parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
be automatically included with these to correctly handle predictions
equal to exactly 0 or 1.
"""

def __init__(self,
num_thresholds=200,
curve='ROC',
summation_method='interpolation',
name=None,
dtype=None,
thresholds=None):
# Validate configurations.
if (isinstance(curve, metrics_utils.AUCCurve) and
curve not in list(metrics_utils.AUCCurve)):
raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
curve, list(metrics_utils.AUCCurve)))
if isinstance(
summation_method,
metrics_utils.AUCSummationMethod) and summation_method not in list(
metrics_utils.AUCSummationMethod):
raise ValueError(
'Invalid summation method: "{}". Valid options are: "{}"'.format(
summation_method, list(metrics_utils.AUCSummationMethod)))

# Update properties.
if thresholds is not None:
# If specified, use the supplied thresholds.
self.num_thresholds = len(thresholds) + 2
thresholds = sorted(thresholds)
else:
if num_thresholds <= 1:
raise ValueError('`num_thresholds` must be > 1.')

# Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
# (0, 1).
self.num_thresholds = num_thresholds
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]

# Add an endpoint "threshold" below zero and above one for either
# threshold method to account for floating point imprecisions.
self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
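# For example (illustrative): num_thresholds=5 yields
# [0.0 - epsilon, 0.25, 0.5, 0.75, 1.0 + epsilon].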

if isinstance(curve, metrics_utils.AUCCurve):
self.curve = curve
else:
self.curve = metrics_utils.AUCCurve.from_str(curve)
if isinstance(summation_method, metrics_utils.AUCSummationMethod):
self.summation_method = summation_method
else:
self.summation_method = metrics_utils.AUCSummationMethod.from_str(
summation_method)
super(AUC, self).__init__(name=name, dtype=dtype)

# Create metric variables
self.true_positives = self.add_weight(
'true_positives',
shape=(self.num_thresholds,),
initializer='zeros')
self.true_negatives = self.add_weight(
'true_negatives',
shape=(self.num_thresholds,),
initializer='zeros')
self.false_positives = self.add_weight(
'false_positives',
shape=(self.num_thresholds,),
initializer='zeros')
self.false_negatives = self.add_weight(
'false_negatives',
shape=(self.num_thresholds,),
initializer='zeros')

def update_state(self, y_true, y_pred, sample_weight=None):
return metrics_utils.update_confusion_matrix_variables({
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
}, y_true, y_pred, self.thresholds, sample_weight=sample_weight)

def interpolate_pr_auc(self):
"""Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

https://www.biostat.wisc.edu/~page/rocpr.pdf

Note here we derive & use a closed formula not present in the paper
as follows:

Precision = TP / (TP + FP) = TP / P

Modeling all of TP (true positive), FP (false positive) and their sum
P = TP + FP (predicted positive) as varying linearly within each interval
[A, B] between successive thresholds, we get

Precision slope = dTP / dP
= (TP_B - TP_A) / (P_B - P_A)
= (TP - TP_A) / (P - P_A)
Precision = (TP_A + slope * (P - P_A)) / P

The area within the interval is (slope / total_pos_weight) times

int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}

where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in

int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)

Bringing back the factor (slope / total_pos_weight) we'd put aside, we get

slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight

where dTP == TP_B - TP_A.

Note that when P_A == 0 the above calculation simplifies into

int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)

which is really equivalent to imputing constant precision throughout the
first bucket having >0 true positives.
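
As an illustrative check with hand-picked values: if TP rises from 4 to 6
and P from 8 to 10 across an interval, then slope = 1,
intercept = 4 - 1 * 8 = -4, and the integral is
6 - 4 + (-4) * log(10 / 8) ≈ 1.107, consistent with precision moving from
0.5 to 0.6 over dP = 2 (average ≈ 0.55).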

# Returns
pr_auc: an approximation of the area under the P-R curve.
"""
dtp = (self.true_positives[:self.num_thresholds - 1] -
self.true_positives[1:])
p = self.true_positives + self.false_positives
dp = p[:self.num_thresholds - 1] - p[1:]

# Clamp denominators to a small positive epsilon so empty buckets
# (where dp == 0 or no positives exist) yield 0 instead of NaN; this
# mirrors a "safe divide".
prec_slope = dtp / K.maximum(dp, K.epsilon())
intercept = self.true_positives[1:] - (prec_slope * p[1:])

# Elementwise logical AND: the ratio p[i] / p[i + 1] is only taken where
# both counts are positive.
pMin = K.expand_dims(p[:self.num_thresholds - 1] > 0, 0)
pMax = K.expand_dims(p[1:] > 0, 0)
are_different = K.concatenate([pMin, pMax], axis=0)
switch_condition = K.all(are_different, axis=0)

safe_p_ratio = K.switch(
switch_condition,
p[:self.num_thresholds - 1] / K.maximum(p[1:], K.epsilon()),
K.ones_like(p[1:]))

numer = prec_slope * (dtp + intercept * K.log(safe_p_ratio))
denom = K.maximum(self.true_positives[1:] + self.false_negatives[1:],
K.epsilon())
return K.sum(numer / denom)

def result(self):
if (self.curve == metrics_utils.AUCCurve.PR and
(self.summation_method ==
metrics_utils.AUCSummationMethod.INTERPOLATION)):
# This use case is different and is handled separately.
return self.interpolate_pr_auc()

# Set `x` and `y` values for the curves based on `curve` config.
recall = K.switch(
K.greater((self.true_positives), 0),
(self.true_positives /
(self.true_positives + self.false_negatives)),
K.zeros_like(self.true_positives))
if self.curve == metrics_utils.AUCCurve.ROC:
fp_rate = K.switch(
K.greater((self.false_positives), 0),
(self.false_positives /
(self.false_positives + self.true_negatives)),
K.zeros_like(self.false_positives))
x = fp_rate
y = recall
else: # curve == 'PR'.
precision = K.switch(
K.greater((self.true_positives), 0),
(self.true_positives / (self.true_positives + self.false_positives)),
K.zeros_like(self.true_positives))
x = recall
y = precision

# Find the rectangle heights based on `summation_method`.
if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
# Note: the case ('PR', 'interpolation') has been handled above.
heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
heights = K.minimum(y[:self.num_thresholds - 1], y[1:])
else: # summation_method == metrics_utils.AUCSummationMethod.MAJORING
heights = K.maximum(y[:self.num_thresholds - 1], y[1:])

# Sum up the areas of all the rectangles.
return K.sum((x[:self.num_thresholds - 1] - x[1:]) * heights)

def reset_states(self):
K.batch_set_value(
[(v, np.zeros((self.num_thresholds,))) for v in self.variables])

def get_config(self):
config = {
'num_thresholds': self.num_thresholds,
'curve': self.curve.value,
'summation_method': self.summation_method.value,
# We remove the endpoint thresholds as an inverse of how the thresholds
# were initialized. This ensures that a metric initialized from this
# config has the same thresholds.
'thresholds': self.thresholds[1:-1],
}
base_config = super(AUC, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
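
# A hypothetical round-trip sketch (illustrative, not part of the class):
# since `get_config` strips the endpoint thresholds that `__init__` re-adds,
# `AUC(**m.get_config())` rebuilds an equivalent metric:
#
# m = AUC(num_thresholds=3, curve='PR')
# m2 = AUC(**m.get_config())
# assert m2.thresholds == m.thresholds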


def accuracy(y_true, y_pred):
if not K.is_tensor(y_pred):
y_pred = K.constant(y_pred)