diff --git a/keras/metrics.py b/keras/metrics.py
index d767562e2e62..b586846d661b 100644
--- a/keras/metrics.py
+++ b/keras/metrics.py
@@ -1150,6 +1150,396 @@ def __init__(self, thresholds=None, name=None, dtype=None):
             dtype=dtype)
 
 
+class SensitivitySpecificityBase(Metric):
+    """Abstract base class for computing sensitivity and specificity.
+
+    For additional information about specificity and sensitivity, see the
+    following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+    """
+
+    def __init__(self, value, num_thresholds=200, name=None, dtype=None):
+        super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
+        if num_thresholds <= 0:
+            raise ValueError('`num_thresholds` must be > 0.')
+        self.value = value
+        self.true_positives = self.add_weight(
+            'true_positives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.true_negatives = self.add_weight(
+            'true_negatives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.false_positives = self.add_weight(
+            'false_positives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.false_negatives = self.add_weight(
+            'false_negatives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+
+        # Compute `num_thresholds` thresholds in [0, 1]
+        if num_thresholds == 1:
+            self.thresholds = [0.5]
+        else:
+            thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                          for i in range(num_thresholds - 2)]
+            self.thresholds = [0.0] + thresholds + [1.0]
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        return metrics_utils.update_confusion_matrix_variables(
+            {
+                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
+                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
+                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
+                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
+            },
+            y_true,
+            y_pred,
+            thresholds=self.thresholds,
+            sample_weight=sample_weight)
+
+    def reset_states(self):
+        num_thresholds = len(self.thresholds)
+        K.batch_set_value(
+            [(v, np.zeros((num_thresholds,))) for v in self.variables])
+
+
+class SensitivityAtSpecificity(SensitivitySpecificityBase):
+    """Computes the sensitivity at a given specificity.
+
+    `Sensitivity` measures the proportion of actual positives that are correctly
+    identified as such (tp / (tp + fn)).
+    `Specificity` measures the proportion of actual negatives that are correctly
+    identified as such (tn / (tn + fp)).
+
+    This metric creates four local variables, `true_positives`, `true_negatives`,
+    `false_positives` and `false_negatives` that are used to compute the
+    sensitivity at the given specificity. The threshold for the given specificity
+    value is computed and used to evaluate the corresponding sensitivity.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    For additional information about specificity and sensitivity, see the
+    following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile(
+        'sgd',
+        loss='mse',
+        metrics=[keras.metrics.SensitivityAtSpecificity(specificity=0.5)])
+    ```
+
+    # Arguments
+        specificity: A scalar value in range `[0, 1]`.
+        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+            use for matching the given specificity.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+ """ + + def __init__(self, specificity, num_thresholds=200, name=None, dtype=None): + if specificity < 0 or specificity > 1: + raise ValueError('`specificity` must be in the range [0, 1].') + self.specificity = specificity + self.num_thresholds = num_thresholds + super(SensitivityAtSpecificity, self).__init__( + specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) + + def result(self): + # Calculate specificities at all the thresholds. + specificities = K.switch( + K.greater(self.true_negatives + self.false_positives, 0), + (self.true_negatives / (self.true_negatives + self.false_positives)), + K.zeros_like(self.thresholds)) + + # Find the index of the threshold where the specificity is closest to the + # given specificity. + min_index = K.argmin( + K.abs(specificities - self.value), axis=0) + min_index = K.cast(min_index, 'int32') + + # Compute sensitivity at that index. + return K.switch( + K.greater((self.true_positives[min_index] + + self.false_negatives[min_index]), 0), + (self.true_positives[min_index] / + (self.true_positives[min_index] + self.false_negatives[min_index])), + K.zeros_like(self.true_positives[min_index])) + + def get_config(self): + config = { + 'num_thresholds': self.num_thresholds, + 'specificity': self.specificity + } + base_config = super(SensitivityAtSpecificity, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class AUC(Metric): + """Computes the approximate AUC (Area under the curve) via a Riemann sum. + + This metric creates four local variables, `true_positives`, `true_negatives`, + `false_positives` and `false_negatives` that are used to compute the AUC. + To discretize the AUC curve, a linearly spaced set of thresholds is used to + compute pairs of recall and precision values. The area under the ROC-curve is + therefore computed using the height of the recall values by the false positive + rate, while the area under the PR-curve is the computed using the height of + the precision values by the recall. + + This value is ultimately returned as `auc`, an idempotent operation that + computes the area under a discretized curve of precision versus recall values + (computed using the aforementioned variables). The `num_thresholds` variable + controls the degree of discretization with larger numbers of thresholds more + closely approximating the true AUC. The quality of the approximation may vary + dramatically depending on `num_thresholds`. The `thresholds` parameter can be + used to manually specify thresholds which split the predictions more evenly. + + For best results, `predictions` should be distributed approximately uniformly + in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC + approximation may be poor if this is not the case. Setting `summation_method` + to 'minoring' or 'majoring' can help quantify the error in the approximation + by providing lower or upper bound estimate of the AUC. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Usage with the compile API: + + ```python + model = keras.Model(inputs, outputs) + model.compile('sgd', loss='mse', metrics=[keras.metrics.AUC()]) + ``` + + # Arguments + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use when discretizing the roc curve. Values must be > 1. + curve: (Optional) Specifies the name of the curve to be computed, 'ROC' + [default] or 'PR' for the Precision-Recall-curve. 
+        summation_method: (Optional) Specifies the Riemann summation method
+            used (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation'
+            [default], which applies a mid-point summation scheme for `ROC`;
+            for PR-AUC, it interpolates (true/false) positives but not the
+            ratio that is precision (see Davis & Goadrich 2006 for details);
+            'minoring', which applies left summation for increasing intervals
+            and right summation for decreasing intervals; 'majoring', which
+            does the opposite.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        thresholds: (Optional) A list of floating point values to use as the
+            thresholds for discretizing the curve. If set, the `num_thresholds`
+            parameter is ignored. Values should be in [0, 1]. Endpoint
+            thresholds equal to {-epsilon, 1+epsilon} for a small positive
+            epsilon value will be automatically included with these to
+            correctly handle predictions equal to exactly 0 or 1.
+    """
+
+    def __init__(self,
+                 num_thresholds=200,
+                 curve='ROC',
+                 summation_method='interpolation',
+                 name=None,
+                 dtype=None,
+                 thresholds=None):
+        # Validate configurations.
+        if (isinstance(curve, metrics_utils.AUCCurve) and
+                curve not in list(metrics_utils.AUCCurve)):
+            raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
+                curve, list(metrics_utils.AUCCurve)))
+        if isinstance(
+                summation_method,
+                metrics_utils.AUCSummationMethod) and summation_method not in list(
+                    metrics_utils.AUCSummationMethod):
+            raise ValueError(
+                'Invalid summation method: "{}". Valid options are: "{}"'.format(
+                    summation_method, list(metrics_utils.AUCSummationMethod)))
+
+        # Update properties.
+        if thresholds is not None:
+            # If specified, use the supplied thresholds.
+            self.num_thresholds = len(thresholds) + 2
+            thresholds = sorted(thresholds)
+        else:
+            if num_thresholds <= 1:
+                raise ValueError('`num_thresholds` must be > 1.')
+
+            # Otherwise, linearly interpolate (num_thresholds - 2) thresholds
+            # in (0, 1).
+            self.num_thresholds = num_thresholds
+            thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                          for i in range(num_thresholds - 2)]
+
+        # Add an endpoint "threshold" below zero and above one for either
+        # threshold method to account for floating point imprecisions.
+        self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
+
+        if isinstance(curve, metrics_utils.AUCCurve):
+            self.curve = curve
+        else:
+            self.curve = metrics_utils.AUCCurve.from_str(curve)
+        if isinstance(summation_method, metrics_utils.AUCSummationMethod):
+            self.summation_method = summation_method
+        else:
+            self.summation_method = metrics_utils.AUCSummationMethod.from_str(
+                summation_method)
+        super(AUC, self).__init__(name=name, dtype=dtype)
+
+        # Create metric variables
+        self.true_positives = self.add_weight(
+            'true_positives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.true_negatives = self.add_weight(
+            'true_negatives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.false_positives = self.add_weight(
+            'false_positives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.false_negatives = self.add_weight(
+            'false_negatives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        return metrics_utils.update_confusion_matrix_variables({
+            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
+            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
+            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
+            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
+        }, y_true, y_pred, self.thresholds, sample_weight=sample_weight)
+
+    def interpolate_pr_auc(self):
+        """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+        https://www.biostat.wisc.edu/~page/rocpr.pdf
+
+        Note here we derive & use a closed formula not present in the paper
+        as follows:
+
+            Precision = TP / (TP + FP) = TP / P
+
+        Modeling all of TP (true positive), FP (false positive) and their sum
+        P = TP + FP (predicted positive) as varying linearly within each
+        interval [A, B] between successive thresholds, we get
+
+            Precision slope = dTP / dP
+                            = (TP_B - TP_A) / (P_B - P_A)
+                            = (TP - TP_A) / (P - P_A)
+            Precision = (TP_A + slope * (P - P_A)) / P
+
+        The area within the interval is (slope / total_pos_weight) times
+
+            int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+            int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+
+        where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+
+            int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+
+        Bringing back the factor (slope / total_pos_weight) we'd put aside, we
+        get
+
+            slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+
+        where dTP == TP_B - TP_A.
+
+        Note that when P_A == 0 the above calculation simplifies into
+
+            int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+
+        which is really equivalent to imputing constant precision throughout
+        the first bucket having >0 true positives.
+
+        # Returns
+            pr_auc: an approximation of the area under the P-R curve.
+ """ + dtp = self.true_positives[:self.num_thresholds - + 1] - self.true_positives[1:] + p = self.true_positives + self.false_positives + dp = p[:self.num_thresholds - 1] - p[1:] + + prec_slope = dtp / K.maximum(dp, 0) + intercept = self.true_positives[1:] - (prec_slope * p[1:]) + + # Logical and + pMin = K.expand_dims(p[:self.num_thresholds - 1] > 0, 0) + pMax = K.expand_dims(p[1:] > 0, 0) + are_different = K.concatenate([pMin, pMax], axis=0) + switch_condition = K.all(are_different, axis=0) + + safe_p_ratio = K.switch( + switch_condition, + p[:self.num_thresholds - 1] / K.maximum(p[1:], 0), + K.ones_like(p[1:])) + + numer = prec_slope * (dtp + intercept * K.log(safe_p_ratio)) + denom = K.maximum(self.true_positives[1:] + self.false_negatives[1:], 0) + return K.sum((numer / denom)) + + def result(self): + if (self.curve == metrics_utils.AUCCurve.PR and + (self.summation_method == + metrics_utils.AUCSummationMethod.INTERPOLATION)): + # This use case is different and is handled separately. + return self.interpolate_pr_auc() + + # Set `x` and `y` values for the curves based on `curve` config. + recall = K.switch( + K.greater((self.true_positives), 0), + (self.true_positives / + (self.true_positives + self.false_negatives)), + K.zeros_like(self.true_positives)) + if self.curve == metrics_utils.AUCCurve.ROC: + fp_rate = K.switch( + K.greater((self.false_positives), 0), + (self.false_positives / + (self.false_positives + self.true_negatives)), + K.zeros_like(self.false_positives)) + x = fp_rate + y = recall + else: # curve == 'PR'. + precision = K.switch( + K.greater((self.true_positives), 0), + (self.true_positives / (self.true_positives + self.false_positives)), + K.zeros_like(self.true_positives)) + x = recall + y = precision + + # Find the rectangle heights based on `summation_method`. + if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: + # Note: the case ('PR', 'interpolation') has been handled above. + heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. + elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: + heights = K.minimum(y[:self.num_thresholds - 1], y[1:]) + else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: + heights = K.maximum(y[:self.num_thresholds - 1], y[1:]) + + # Sum up the areas of all the rectangles. + return K.sum((x[:self.num_thresholds - 1] - x[1:]) * heights) + + def reset_states(self): + K.batch_set_value( + [(v, np.zeros((self.num_thresholds,))) for v in self.variables]) + + def get_config(self): + config = { + 'num_thresholds': self.num_thresholds, + 'curve': self.curve.value, + 'summation_method': self.summation_method.value, + # We remove the endpoint thresholds as an inverse of how the thresholds + # were initialized. This ensures that a metric initialized from this + # config has the same thresholds. + 'thresholds': self.thresholds[1:-1], + } + base_config = super(AUC, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def accuracy(y_true, y_pred): if not K.is_tensor(y_pred): y_pred = K.constant(y_pred) diff --git a/keras/utils/metrics_utils.py b/keras/utils/metrics_utils.py index 6177899a286b..3548fd490a3c 100644 --- a/keras/utils/metrics_utils.py +++ b/keras/utils/metrics_utils.py @@ -9,6 +9,9 @@ from . import losses_utils +NEG_INF = -1e10 + + class Reduction(object): """Types of metrics reduction. 
@@ -94,7 +97,53 @@ class ConfusionMatrix(Enum):
     FALSE_NEGATIVES = 'fn'
 
 
+class AUCCurve(Enum):
+    """Type of AUC Curve (ROC or PR)."""
+    ROC = 'ROC'
+    PR = 'PR'
+
+    @staticmethod
+    def from_str(key):
+        if key in ('pr', 'PR'):
+            return AUCCurve.PR
+        elif key in ('roc', 'ROC'):
+            return AUCCurve.ROC
+        else:
+            raise ValueError('Invalid AUC curve value "%s".' % key)
+
+
+class AUCSummationMethod(Enum):
+    """Type of AUC summation method.
+
+    https://en.wikipedia.org/wiki/Riemann_sum
+
+    Contains the following values:
+    * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
+      `PR` curve, interpolates (true/false) positives but not the ratio that is
+      precision (see Davis & Goadrich 2006 for details).
+    * 'minoring': Applies left summation for increasing intervals and right
+      summation for decreasing intervals.
+    * 'majoring': Applies right summation for increasing intervals and left
+      summation for decreasing intervals.
+    """
+    INTERPOLATION = 'interpolation'
+    MAJORING = 'majoring'
+    MINORING = 'minoring'
+
+    @staticmethod
+    def from_str(key):
+        if key in ('interpolation', 'Interpolation'):
+            return AUCSummationMethod.INTERPOLATION
+        elif key in ('majoring', 'Majoring'):
+            return AUCSummationMethod.MAJORING
+        elif key in ('minoring', 'Minoring'):
+            return AUCSummationMethod.MINORING
+        else:
+            raise ValueError('Invalid AUC summation method value "%s".' % key)
+
+
 def weighted_assign_add(label, pred, weights, var):
+    # Logical and
     label = K.expand_dims(label, 0)
     pred = K.expand_dims(pred, 0)
     are_different = K.concatenate([label, pred], axis=0)
@@ -222,10 +271,6 @@ def update_confusion_matrix_variables(variables_to_update, y_pred,
                           K.cast(sample_weight, dtype=K.floatx()))
         weights_tiled = K.tile(
             K.reshape(weights, [1, -1]), [num_thresholds, 1])
-
-        print('label_is_pos: ', label_is_pos)
-        print('pred_is_pos: ', pred_is_pos)
-        print('Sample_weight: ', weights_tiled)
     else:
         weights_tiled = None
 
diff --git a/tests/keras/metrics_confusion_matrix_test.py b/tests/keras/metrics_confusion_matrix_test.py
index 364a35c22b1e..9697d573f1bb 100644
--- a/tests/keras/metrics_confusion_matrix_test.py
+++ b/tests/keras/metrics_confusion_matrix_test.py
@@ -4,6 +4,7 @@
 
 from keras import metrics
 from keras import backend as K
+from keras.utils import metrics_utils
 
 if K.backend() != 'tensorflow':
     # Need TensorFlow to use metric.__call__
@@ -241,3 +242,285 @@ def test_weighted_with_thresholds(self):
         result = fn_obj(y_true, y_pred, sample_weight=sample_weight)
 
         assert np.allclose([4., 16., 23.], K.eval(result))
+
+
+class TestSensitivityAtSpecificity(object):
+
+    def test_config(self):
+        s_obj = metrics.SensitivityAtSpecificity(
+            0.4, num_thresholds=100, name='sensitivity_at_specificity_1')
+        assert s_obj.name == 'sensitivity_at_specificity_1'
+        assert len(s_obj.weights) == 4
+        assert s_obj.specificity == 0.4
+        assert s_obj.num_thresholds == 100
+
+        # Check save and restore config
+        s_obj2 = metrics.SensitivityAtSpecificity.from_config(s_obj.get_config())
+        assert s_obj2.name == 'sensitivity_at_specificity_1'
+        assert len(s_obj2.weights) == 4
+        assert s_obj2.specificity == 0.4
+        assert s_obj2.num_thresholds == 100
+
+    def test_unweighted_all_correct(self):
+        s_obj = metrics.SensitivityAtSpecificity(0.7)
+        inputs = np.random.randint(0, 2, size=(100, 1))
+        y_pred = K.constant(inputs, dtype='float32')
+        y_true = K.constant(inputs)
+        result = s_obj(y_true, y_pred)
+        assert np.isclose(1, K.eval(result))
+
+    def test_unweighted_high_specificity(self):
+        s_obj = metrics.SensitivityAtSpecificity(0.8)
+        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
+        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+        y_pred = K.constant(pred_values, dtype='float32')
+        y_true = K.constant(label_values)
+        result = s_obj(y_true, y_pred)
+        assert np.isclose(0.8, K.eval(result))
+
+    def test_unweighted_low_specificity(self):
+        s_obj = metrics.SensitivityAtSpecificity(0.4)
+        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+        y_pred = K.constant(pred_values, dtype='float32')
+        y_true = K.constant(label_values)
+        result = s_obj(y_true, y_pred)
+        assert np.isclose(0.6, K.eval(result))
+
+    def test_weighted(self):
+        s_obj = metrics.SensitivityAtSpecificity(0.4)
+        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+        weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+        y_pred = K.constant(pred_values, dtype='float32')
+        y_true = K.constant(label_values, dtype='float32')
+        weights = K.constant(weight_values)
+        result = s_obj(y_true, y_pred, sample_weight=weights)
+        assert np.isclose(0.675, K.eval(result))
+
+    def test_invalid_specificity(self):
+        with pytest.raises(Exception):
+            metrics.SensitivityAtSpecificity(-1)
+
+    def test_invalid_num_thresholds(self):
+        with pytest.raises(Exception):
+            metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
+
+
+class TestAUC(object):
+
+    def setup(self):
+        self.num_thresholds = 3
+        self.y_pred = K.constant([0, 0.5, 0.3, 0.9], dtype='float32')
+        self.y_true = K.constant([0, 0, 1, 1])
+        self.sample_weight = [1, 2, 3, 4]
+
+        # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
+        # y_pred when threshold = 0 - 1e-7 : [1, 1, 1, 1]
+        # y_pred when threshold = 0.5 : [0, 0, 0, 1]
+        # y_pred when threshold = 1 + 1e-7 : [0, 0, 0, 0]
+
+        # without sample_weight:
+        # tp = np.sum([[0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]], axis=1)
+        # fp = np.sum([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
+        # fn = np.sum([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]], axis=1)
+        # tn = np.sum([[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]], axis=1)
+
+        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+
+        # with sample_weight:
+        # tp = np.sum([[0, 0, 3, 4], [0, 0, 0, 4], [0, 0, 0, 0]], axis=1)
+        # fp = np.sum([[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)
+        # fn = np.sum([[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 3, 4]], axis=1)
+        # tn = np.sum([[0, 0, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]], axis=1)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+
+    def test_config(self):
+        auc_obj = metrics.AUC(
+            num_thresholds=100,
+            curve='PR',
+            summation_method='majoring',
+            name='auc_1')
+        assert auc_obj.name == 'auc_1'
+        assert len(auc_obj.weights) == 4
+        assert auc_obj.num_thresholds == 100
+        assert auc_obj.curve == metrics_utils.AUCCurve.PR
+        assert auc_obj.summation_method == metrics_utils.AUCSummationMethod.MAJORING
+
+        # Check save and restore config.
+        auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
+        assert auc_obj2.name == 'auc_1'
+        assert len(auc_obj2.weights) == 4
+        assert auc_obj2.num_thresholds == 100
+        assert auc_obj2.curve == metrics_utils.AUCCurve.PR
+        assert auc_obj2.summation_method == metrics_utils.AUCSummationMethod.MAJORING
+
+    def test_config_manual_thresholds(self):
+        auc_obj = metrics.AUC(
+            num_thresholds=None,
+            curve='PR',
+            summation_method='majoring',
+            name='auc_1',
+            thresholds=[0.3, 0.5])
+        assert auc_obj.name == 'auc_1'
+        assert len(auc_obj.weights) == 4
+        assert auc_obj.num_thresholds == 4
+        assert np.allclose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0], atol=1e-3)
+        assert auc_obj.curve == metrics_utils.AUCCurve.PR
+        assert auc_obj.summation_method == metrics_utils.AUCSummationMethod.MAJORING
+
+        # Check save and restore config.
+        auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())
+        assert auc_obj2.name == 'auc_1'
+        assert len(auc_obj2.weights) == 4
+        assert auc_obj2.num_thresholds == 4
+        assert auc_obj2.curve == metrics_utils.AUCCurve.PR
+        assert auc_obj2.summation_method == metrics_utils.AUCSummationMethod.MAJORING
+
+    def test_unweighted_all_correct(self):
+        self.setup()
+        auc_obj = metrics.AUC()
+        result = auc_obj(self.y_true, self.y_true)
+        assert K.eval(result) == 1
+
+    def test_unweighted(self):
+        self.setup()
+        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
+        result = auc_obj(self.y_true, self.y_pred)
+
+        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+        # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+        # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+        # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+        # widths = [(1 - 0), (0 - 0)] = [1, 0]
+        expected_result = (0.75 * 1 + 0.25 * 0)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_manual_thresholds(self):
+        self.setup()
+        # Verify that when specified, thresholds are used instead of
+        # num_thresholds.
+        auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])
+        assert auc_obj.num_thresholds == 3
+        assert np.allclose(auc_obj.thresholds, [0.0, 0.5, 1.0], atol=1e-3)
+        result = auc_obj(self.y_true, self.y_pred)
+
+        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+        # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]
+        # fp_rate = [2/2, 0, 0] = [1, 0, 0]
+        # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]
+        # widths = [(1 - 0), (0 - 0)] = [1, 0]
+        expected_result = (0.75 * 1 + 0.25 * 0)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_roc_interpolation(self):
+        self.setup()
+        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
+        # fp_rate = [3/3, 0, 0] = [1, 0, 0]
+        # heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855]
+        # widths = [(1 - 0), (0 - 0)] = [1, 0]
+        expected_result = (0.7855 * 1 + 0.2855 * 0)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_roc_majoring(self):
+        self.setup()
+        auc_obj = metrics.AUC(
+            num_thresholds=self.num_thresholds, summation_method='majoring')
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
+        # fp_rate = [3/3, 0, 0] = [1, 0, 0]
+        # heights = [max(1, 0.571), max(0.571, 0)] = [1, 0.571]
+        # widths = [(1 - 0), (0 - 0)] = [1, 0]
+        expected_result = (1 * 1 + 0.571 * 0)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_roc_minoring(self):
+        self.setup()
+        auc_obj = metrics.AUC(
+            num_thresholds=self.num_thresholds, summation_method='minoring')
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
+        # fp_rate = [3/3, 0, 0] = [1, 0, 0]
+        # heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0]
+        # widths = [(1 - 0), (0 - 0)] = [1, 0]
+        expected_result = (0.571 * 1 + 0 * 0)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_pr_majoring(self):
+        self.setup()
+        auc_obj = metrics.AUC(
+            num_thresholds=self.num_thresholds,
+            curve='PR',
+            summation_method='majoring')
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
+        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
+        # heights = [max(0.7, 1), max(1, 0)] = [1, 1]
+        # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
+        expected_result = (1 * 0.429 + 1 * 0.571)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_pr_minoring(self):
+        self.setup()
+        auc_obj = metrics.AUC(
+            num_thresholds=self.num_thresholds,
+            curve='PR',
+            summation_method='minoring')
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]
+        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]
+        # heights = [min(0.7, 1), min(1, 0)] = [0.7, 0]
+        # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]
+        expected_result = (0.7 * 0.429 + 0 * 0.571)
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_weighted_pr_interpolation(self):
+        self.setup()
+        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve='PR')
+        result = auc_obj(self.y_true, self.y_pred, sample_weight=self.sample_weight)
+
+        # auc = (slope / Total Pos) * [dTP + intercept * log(Pb/Pa)]
+
+        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
+        # P = tp + fp = [10, 4, 0]
+        # dTP = [7-4, 4-0] = [3, 4]
+        # dP = [10-4, 4-0] = [6, 4]
+        # slope = dTP/dP = [0.5, 1]
+        # intercept = TPb - (slope * Pb) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0]
+        # (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1]
+        # auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))]
+        #                = [2.416, 4]
+        # auc = [2.416, 4]/(tp[1:]+fn[1:])
+        # expected_result = (2.416 / 7 + 4 / 7)
+        expected_result = 0.345 + 0.571
+        assert np.allclose(K.eval(result), expected_result, atol=1e-3)
+
+    def test_invalid_num_thresholds(self):
+        with pytest.raises(Exception):
+            metrics.AUC(num_thresholds=-1)
+
+        with pytest.raises(Exception):
+            metrics.AUC(num_thresholds=1)
+
+    def test_invalid_curve(self):
+        with pytest.raises(Exception):
+            metrics.AUC(curve='Invalid')
+
+    def test_invalid_summation_method(self):
+        with pytest.raises(Exception):
+            metrics.AUC(summation_method='Invalid')
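Note (not part of the diff): the interpolated PR-AUC value asserted in `test_weighted_pr_interpolation` can be re-derived by hand. The sketch below uses plain NumPy and the weighted confusion-matrix totals worked out in the test comments; the array values are taken from those comments, and the simplified guards (dP and total positives are strictly positive in this example) are assumptions made for brevity, whereas the real metric performs the equivalent with backend ops inside `AUC.interpolate_pr_auc`.

```python
import numpy as np

# Weighted confusion-matrix totals at the three thresholds (from the
# comments in TestAUC.setup / test_weighted_pr_interpolation).
tp = np.array([7., 4., 0.])
fp = np.array([3., 0., 0.])
fn = np.array([0., 3., 7.])

p = tp + fp                          # predicted positives P: [10, 4, 0]
dtp = tp[:-1] - tp[1:]               # dTP per interval: [3, 4]
dp = p[:-1] - p[1:]                  # dP per interval: [6, 4] (positive here,
                                     # so no div-by-zero guard is needed)
slope = dtp / dp                     # precision slope: [0.5, 1.0]
intercept = tp[1:] - slope * p[1:]   # TP_b - slope * P_b: [2.0, 0.0]

# P_a / P_b, forced to 1 whenever either endpoint has no predicted
# positives (mirrors the K.switch guard in the metric).
safe_ratio = np.where((p[:-1] > 0) & (p[1:] > 0),
                      p[:-1] / np.where(p[1:] > 0, p[1:], 1.),
                      1.)                      # [2.5, 1.0]

numer = slope * (dtp + intercept * np.log(safe_ratio))
denom = tp[1:] + fn[1:]              # total positives per interval: [7, 7]
pr_auc = float(np.sum(numer / denom))

print(round(pr_auc, 3))  # ~0.917, within atol=1e-3 of 0.345 + 0.571 above
```

This matches the closed-form derivation in the `interpolate_pr_auc` docstring: per interval, slope * (dTP + intercept * log(P_b / P_a)) divided by the total positive weight, summed over intervals.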