Adding CategoricalAccuracy metric class.
pavithrasv committed Aug 30, 2019
commit 9689ab361418cebbe16975ae7cf2911703d37bc8
42 changes: 42 additions & 0 deletions keras/metrics.py
@@ -415,6 +415,48 @@ def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
binary_accuracy, name, dtype=dtype, threshold=threshold)


class CategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.

For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
would be .3. You can provide logits of classes as `y_pred`, since the argmax
of logits and probabilities is the same.

This metric creates two local variables, `total` and `count`, that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `categorical accuracy`: an idempotent operation that
simply divides `total` by `count`.

`y_pred` and `y_true` should be passed in as vectors of probabilities, rather
than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[keras.metrics.CategoricalAccuracy()])
```
"""

def __init__(self, name='categorical_accuracy', dtype=None):
"""Creates a `CategoricalAccuracy` instance.

Args:
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
super(CategoricalAccuracy, self).__init__(
categorical_accuracy, name, dtype=dtype)


def accuracy(y_true, y_pred):
if not K.is_tensor(y_pred):
y_pred = K.constant(y_pred)
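The weighted-accuracy arithmetic described in the docstring above can be checked with plain NumPy; a minimal sketch, illustrative only and not part of this diff:

```python
import numpy as np

# Values from the docstring example.
y_true = np.array([[0, 0, 1], [0, 1, 0]])
y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0.0]])

# A sample counts as a match when the argmax of y_pred equals the
# argmax of the one-hot y_true (the same holds for logits).
matches = (np.argmax(y_pred, axis=-1) == np.argmax(y_true, axis=-1)).astype(float)
print(matches.mean())  # 0.5: only the second sample is predicted correctly

# With sample weights [0.7, 0.3] the metric is the weighted mean of the matches.
weights = np.array([0.7, 0.3])
print(np.sum(matches * weights) / np.sum(weights))  # 0.3
```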
22 changes: 22 additions & 0 deletions tests/keras/metrics_test.py
@@ -239,6 +239,28 @@ def test_binary_accuracy_threshold(self):
result = K.eval(result_t)
assert np.isclose(result, 0.5, atol=1e-3)

def test_categorical_accuracy(self):
acc_obj = metrics.CategoricalAccuracy(name='my_acc')

# check config
assert acc_obj.name == 'my_acc'
assert acc_obj.stateful
assert len(acc_obj.weights) == 2
assert acc_obj.dtype == 'float32'

# verify that correct value is returned
result_t = acc_obj([[0, 0, 1], [0, 1, 0]],
[[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
result = K.eval(result_t)
assert result == 1 # 2/2

# check with sample_weight
result_t = acc_obj([[0, 0, 1], [0, 1, 0]],
[[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
[[0.5], [0.2]])
result = K.eval(result_t)
assert np.isclose(result, 2.5 / 2.7, atol=1e-3) # 2.5/2.7


class TestMeanSquaredErrorTest(object):

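The expected value `2.5 / 2.7` in the weighted test above follows from the running `total`/`count` state the metric accumulates across calls; a plain-Python sketch of that bookkeeping, based on the docstring's description of the state variables:

```python
# First call: both samples correct, default weight 1 each.
total = 1.0 * 1 + 1.0 * 1   # weighted sum of matches -> 2.0
count = 1.0 + 1.0           # sum of weights          -> 2.0

# Second call: first sample correct (weight 0.5), second wrong (weight 0.2).
total += 1 * 0.5 + 0 * 0.2  # -> 2.5
count += 0.5 + 0.2          # -> 2.7

print(total / count)        # ~0.9259, i.e. 2.5 / 2.7
```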