diff --git a/keras/metrics.py b/keras/metrics.py
index 1a2a2200191d..260889c6c598 100644
--- a/keras/metrics.py
+++ b/keras/metrics.py
@@ -261,7 +261,7 @@ class Mean(Reduce):
     Usage:
 
     ```python
-    m = tf.keras.metrics.Mean()
+    m = keras.metrics.Mean()
     m.update_state([1, 3, 5, 7])
     m.result()
     ```
@@ -321,7 +321,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
 
     def get_config(self):
         config = {}
         for k, v in six.iteritems(self._fn_kwargs):
-            config[k] = K.eval(v) if is_tensor_or_variable(v) else v
+            config[k] = K.eval(v) if K.is_tensor(v) else v
         base_config = super(MeanMetricWrapper, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -350,7 +350,199 @@ def __init__(self, name='mean_squared_error', dtype=None):
             mean_squared_error, name, dtype=dtype)
 
 
-def binary_accuracy(y_true, y_pred):
+class Accuracy(MeanMetricWrapper):
+    """Calculates how often predictions match labels.
+
+    For example, if `y_true` is [1, 2, 3, 4] and `y_pred` is [0, 2, 3, 4]
+    then the accuracy is 3/4 or .75. If the weights were specified as
+    [1, 1, 0, 0] then the accuracy would be 1/2 or .5.
+
+    This metric creates two local variables, `total` and `count` that are used to
+    compute the frequency with which `y_pred` matches `y_true`. This frequency is
+    ultimately returned as `accuracy`: an idempotent operation that simply
+    divides `total` by `count`.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', loss='mse', metrics=[keras.metrics.Accuracy()])
+    ```
+    """
+
+    def __init__(self, name='accuracy', dtype=None):
+        super(Accuracy, self).__init__(accuracy, name, dtype=dtype)
+
+
+class BinaryAccuracy(MeanMetricWrapper):
+    """Calculates how often predictions match labels.
+
+    For example, if `y_true` is [1, 1, 0, 0] and `y_pred` is [0.98, 1, 0, 0.6]
+    then the binary accuracy is 3/4 or .75. If the weights were specified as
+    [1, 0, 0, 1] then the binary accuracy would be 1/2 or .5.
+
+    This metric creates two local variables, `total` and `count` that are used to
+    compute the frequency with which `y_pred` matches `y_true`. This frequency is
+    ultimately returned as `binary accuracy`: an idempotent operation that simply
+    divides `total` by `count`.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', loss='mse', metrics=[keras.metrics.BinaryAccuracy()])
+    ```
+
+    # Arguments
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        threshold: (Optional) Float representing the threshold for deciding
+            whether prediction values are 1 or 0.
+    """
+
+    def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
+        super(BinaryAccuracy, self).__init__(
+            binary_accuracy, name, dtype=dtype, threshold=threshold)
+
+
+class CategoricalAccuracy(MeanMetricWrapper):
+    """Calculates how often predictions match labels.
+
+    For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
+    [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
+    If the weights were specified as [0.7, 0.3] then the categorical accuracy
+    would be .3. You can provide logits of classes as `y_pred`, since argmax of
+    logits and probabilities is the same.
+
+    This metric creates two local variables, `total` and `count` that are used to
+    compute the frequency with which `y_pred` matches `y_true`. This frequency is
+    ultimately returned as `categorical accuracy`: an idempotent operation that
+    simply divides `total` by `count`.
+
+    `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
+    than as labels. If necessary, use `K.one_hot` to expand `y_true` as a vector.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile(
+        'sgd',
+        loss='mse',
+        metrics=[keras.metrics.CategoricalAccuracy()])
+    ```
+
+    # Arguments
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+    """
+
+    def __init__(self, name='categorical_accuracy', dtype=None):
+        super(CategoricalAccuracy, self).__init__(
+            categorical_accuracy, name, dtype=dtype)
+
+
+class SparseCategoricalAccuracy(MeanMetricWrapper):
+    """Calculates how often predictions match integer labels.
+
+    For example, if `y_true` is [[2], [1]] and `y_pred` is
+    [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
+    If the weights were specified as [0.7, 0.3] then the categorical accuracy
+    would be .3. You can provide logits of classes as `y_pred`, since argmax of
+    logits and probabilities is the same.
+
+    This metric creates two local variables, `total` and `count` that are used to
+    compute the frequency with which `y_pred` matches `y_true`. This frequency is
+    ultimately returned as `sparse categorical accuracy`: an idempotent operation
+    that simply divides `total` by `count`.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile(
+        'sgd',
+        loss='mse',
+        metrics=[keras.metrics.SparseCategoricalAccuracy()])
+    ```
+    """
+
+    def __init__(self, name='sparse_categorical_accuracy', dtype=None):
+        super(SparseCategoricalAccuracy, self).__init__(
+            sparse_categorical_accuracy, name, dtype=dtype)
+
+
+class TopKCategoricalAccuracy(MeanMetricWrapper):
+    """Computes how often targets are in the top `K` predictions.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', metrics=[keras.metrics.TopKCategoricalAccuracy()])
+    ```
+
+    # Arguments
+        k: (Optional) Number of top elements to look at for computing accuracy.
+            Defaults to 5.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+    """
+
+    def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
+        super(TopKCategoricalAccuracy, self).__init__(
+            top_k_categorical_accuracy, name, dtype=dtype, k=k)
+
+
+class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
+    """Computes how often integer targets are in the top `K` predictions.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile(
+        'sgd',
+        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
+    ```
+
+    # Arguments
+        k: (Optional) Number of top elements to look at for computing accuracy.
+            Defaults to 5.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+ """ + + def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None): + super(SparseTopKCategoricalAccuracy, self).__init__( + sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k) + + +def accuracy(y_true, y_pred): + if not K.is_tensor(y_pred): + y_pred = K.constant(y_pred) + y_true = K.cast(y_true, y_pred.dtype) + return K.cast(K.equal(y_true, y_pred), K.floatx()) + + +def binary_accuracy(y_true, y_pred, threshold=0.5): + if threshold != 0.5: + threshold = K.cast(threshold, y_pred.dtype) + y_pred = K.cast(y_pred > threshold, y_pred.dtype) return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1) diff --git a/tests/keras/metrics_test.py b/tests/keras/metrics_test.py index 99d6bfe018c9..e172466ef177 100644 --- a/tests/keras/metrics_test.py +++ b/tests/keras/metrics_test.py @@ -177,6 +177,134 @@ def test_multiple_instances(self): assert K.eval(m.count) == 1 +class TestAccuracy(object): + + def test_accuracy(self): + acc_obj = metrics.Accuracy(name='my_acc') + + # check config + assert acc_obj.name == 'my_acc' + assert acc_obj.stateful + assert len(acc_obj.weights) == 2 + assert acc_obj.dtype == 'float32' + + # verify that correct value is returned + result = K.eval(acc_obj([[1], [2], [3], [4]], [[1], [2], [3], [4]])) + assert result == 1 # 2/2 + + # Check save and restore config + a2 = metrics.Accuracy.from_config(acc_obj.get_config()) + assert a2.name == 'my_acc' + assert a2.stateful + assert len(a2.weights) == 2 + assert a2.dtype, 'float32' + + # check with sample_weight + result_t = acc_obj([[2], [1]], [[2], [0]], sample_weight=[[0.5], [0.2]]) + result = K.eval(result_t) + assert np.isclose(result, 4.5 / 4.7, atol=1e-3) + + def test_binary_accuracy(self): + acc_obj = metrics.BinaryAccuracy(name='my_acc') + + # check config + assert acc_obj.name == 'my_acc' + assert acc_obj.stateful + assert len(acc_obj.weights) == 2 + assert acc_obj.dtype == 'float32' + + # verify that correct value is returned + result_t = acc_obj([[1], [0]], [[1], [0]]) + result = K.eval(result_t) + assert result == 1 # 2/2 + + # check y_pred squeeze + result_t = acc_obj([[1], [1]], [[[1]], [[0]]]) + result = K.eval(result_t) + assert np.isclose(result, 3. / 4., atol=1e-3) + + # check y_true squeeze + result_t = acc_obj([[[1]], [[1]]], [[1], [0]]) + result = K.eval(result_t) + assert np.isclose(result, 4. 
+
+        # check with sample_weight
+        result_t = acc_obj([[1], [1]], [[1], [0]], [[0.5], [0.2]])
+        result = K.eval(result_t)
+        assert np.isclose(result, 4.5 / 6.7, atol=1e-3)
+
+    def test_binary_accuracy_threshold(self):
+        acc_obj = metrics.BinaryAccuracy(threshold=0.7)
+        result_t = acc_obj([[1], [1], [0], [0]], [[0.9], [0.6], [0.4], [0.8]])
+        result = K.eval(result_t)
+        assert np.isclose(result, 0.5, atol=1e-3)
+
+    def test_categorical_accuracy(self):
+        acc_obj = metrics.CategoricalAccuracy(name='my_acc')
+
+        # check config
+        assert acc_obj.name == 'my_acc'
+        assert acc_obj.stateful
+        assert len(acc_obj.weights) == 2
+        assert acc_obj.dtype == 'float32'
+
+        # verify that correct value is returned
+        result_t = acc_obj([[0, 0, 1], [0, 1, 0]],
+                           [[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
+        result = K.eval(result_t)
+        assert result == 1  # 2/2
+
+        # check with sample_weight
+        result_t = acc_obj([[0, 0, 1], [0, 1, 0]],
+                           [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
+                           [[0.5], [0.2]])
+        result = K.eval(result_t)
+        assert np.isclose(result, 2.5 / 2.7, atol=1e-3)  # 2.5/2.7
+
+    def test_sparse_categorical_accuracy(self):
+        acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
+
+        # check config
+        assert acc_obj.name == 'my_acc'
+        assert acc_obj.stateful
+        assert len(acc_obj.weights) == 2
+        assert acc_obj.dtype == 'float32'
+
+        # verify that correct value is returned
+        result_t = acc_obj([[2], [1]],
+                           [[0.1, 0.1, 0.8],
+                            [0.05, 0.95, 0]])
+        result = K.eval(result_t)
+        assert result == 1  # 2/2
+
+        # check with sample_weight
+        result_t = acc_obj([[2], [1]],
+                           [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
+                           [[0.5], [0.2]])
+        result = K.eval(result_t)
+        assert np.isclose(result, 2.5 / 2.7, atol=1e-3)
+
+    def test_sparse_categorical_accuracy_mismatched_dims(self):
+        acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
+
+        # check config
+        assert acc_obj.name == 'my_acc'
+        assert acc_obj.stateful
+        assert len(acc_obj.weights) == 2
+        assert acc_obj.dtype == 'float32'
+
+        # verify that correct value is returned
+        result_t = acc_obj([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
+        result = K.eval(result_t)
+        assert result == 1  # 2/2
+
+        # check with sample_weight
+        result_t = acc_obj([2, 1], [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
+                           [[0.5], [0.2]])
+        result = K.eval(result_t)
+        assert np.isclose(result, 2.5 / 2.7, atol=1e-3)
+
+
 class TestMeanSquaredErrorTest(object):
 
     def test_config(self):
@@ -204,3 +332,86 @@ def test_weighted(self):
         sample_weight = (1., 1.5, 2., 2.5)
         result = mse_obj(y_true, y_pred, sample_weight=sample_weight)
         np.isclose(0.54285, K.eval(result), atol=1e-5)
+
+
+class TestTopKCategoricalAccuracy(object):
+
+    def test_config(self):
+        a_obj = metrics.TopKCategoricalAccuracy(name='topkca', dtype='int32')
+        assert a_obj.name == 'topkca'
+        assert a_obj.dtype == 'int32'
+
+        a_obj2 = metrics.TopKCategoricalAccuracy.from_config(a_obj.get_config())
+        assert a_obj2.name == 'topkca'
+        assert a_obj2.dtype == 'int32'
+
+    def test_correctness(self):
+        a_obj = metrics.TopKCategoricalAccuracy()
+        y_true = [[0, 0, 1], [0, 1, 0]]
+        y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+
+        result = a_obj(y_true, y_pred)
+        assert 1 == K.eval(result)  # both the samples match
+
+        # With `k` < 5.
+        a_obj = metrics.TopKCategoricalAccuracy(k=1)
+        result = a_obj(y_true, y_pred)
+        assert 0.5 == K.eval(result)  # only sample #2 matches
+
+        # With `k` > 5.
+        y_true = [[0, 0, 1, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0]]
+        y_pred = [[0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4],
+                  [0.05, 0.95, 0, 0, 0, 0, 0]]
+        a_obj = metrics.TopKCategoricalAccuracy(k=6)
+        result = a_obj(y_true, y_pred)
+        assert 0.5 == K.eval(result)  # only 1 sample matches.
+
+    def test_weighted(self):
+        a_obj = metrics.TopKCategoricalAccuracy(k=2)
+        y_true = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
+        y_pred = [[0, 0.9, 0.1], [0, 0.9, 0.1], [0, 0.9, 0.1]]
+        sample_weight = (1.0, 0.0, 1.0)
+        result = a_obj(y_true, y_pred, sample_weight=sample_weight)
+        assert np.allclose(1.0, K.eval(result), atol=1e-5)
+
+
+class TestSparseTopKCategoricalAccuracy(object):
+
+    def test_config(self):
+        a_obj = metrics.SparseTopKCategoricalAccuracy(
+            name='stopkca', dtype='int32')
+        assert a_obj.name == 'stopkca'
+        assert a_obj.dtype == 'int32'
+
+        a_obj2 = metrics.SparseTopKCategoricalAccuracy.from_config(
+            a_obj.get_config())
+        assert a_obj2.name == 'stopkca'
+        assert a_obj2.dtype == 'int32'
+
+    def test_correctness(self):
+        a_obj = metrics.SparseTopKCategoricalAccuracy()
+        y_true = [2, 1]
+        y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
+
+        result = a_obj(y_true, y_pred)
+        assert 1 == K.eval(result)  # both the samples match
+
+        # With `k` < 5.
+        a_obj = metrics.SparseTopKCategoricalAccuracy(k=1)
+        result = a_obj(y_true, y_pred)
+        assert 0.5 == K.eval(result)  # only sample #2 matches
+
+        # With `k` > 5.
+        y_pred = [[0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4],
+                  [0.05, 0.95, 0, 0, 0, 0, 0]]
+        a_obj = metrics.SparseTopKCategoricalAccuracy(k=6)
+        result = a_obj(y_true, y_pred)
+        assert 0.5 == K.eval(result)  # only 1 sample matches.
+
+    def test_weighted(self):
+        a_obj = metrics.SparseTopKCategoricalAccuracy(k=2)
+        y_true = [1, 0, 2]
+        y_pred = [[0, 0.9, 0.1], [0, 0.9, 0.1], [0, 0.9, 0.1]]
+        sample_weight = (1.0, 0.0, 1.0)
+        result = a_obj(y_true, y_pred, sample_weight=sample_weight)
+        assert np.allclose(1.0, K.eval(result), atol=1e-5)
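
For reference, a minimal standalone sketch (not part of the patch itself) of how the new stateful metric objects behave when called directly, mirroring the calls used in the tests above. It assumes the patch is applied and a Keras backend is available.

```python
from keras import backend as K
from keras import metrics

# Stateful accuracy: `total` and `count` accumulate across calls,
# so each call returns the running mean of matches so far.
acc = metrics.Accuracy()
result = acc([[1], [2], [3], [4]], [[0], [2], [3], [4]])
print(K.eval(result))  # 0.75 -- 3 of the 4 predictions match

# BinaryAccuracy with a custom decision threshold (the default is 0.5).
bin_acc = metrics.BinaryAccuracy(threshold=0.7)
result = bin_acc([[1], [1], [0], [0]], [[0.9], [0.6], [0.4], [0.8]])
print(K.eval(result))  # 0.5 -- only the 1st and 3rd predictions are correct
```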