Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding SparseCategoricalAccuracy metric class.
  • Loading branch information
pavithrasv committed Aug 30, 2019
commit fa48da7cfeff38c4077407da210f5b3e15911f9c
33 changes: 33 additions & 0 deletions keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,39 @@ def __init__(self, name='categorical_accuracy', dtype=None):
categorical_accuracy, name, dtype=dtype)


class SparseCategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches integer labels.

For example, if `y_true` is [[2], [1]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
would be .3. You can provide logits of classes as `y_pred`, since argmax of
logits and probabilities are same.

This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `sparse categorical accuracy`: an idempotent operation
that simply divides `total` by `count`.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[keras.metrics.SparseCategoricalAccuracy()])
```
"""

def __init__(self, name='sparse_categorical_accuracy', dtype=None):
super(SparseCategoricalAccuracy, self).__init__(
sparse_categorical_accuracy, name, dtype=dtype)


def accuracy(y_true, y_pred):
if not K.is_tensor(y_pred):
y_pred = K.constant(y_pred)
Expand Down
45 changes: 44 additions & 1 deletion tests/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_categorical_accuracy(self):
assert acc_obj.name == 'my_acc'
assert acc_obj.stateful
assert len(acc_obj.weights) == 2
assert acc_obj.dtype, 'float32'
assert acc_obj.dtype == 'float32'

# verify that correct value is returned
result_t = acc_obj([[0, 0, 1], [0, 1, 0]],
Expand All @@ -261,6 +261,49 @@ def test_categorical_accuracy(self):
result = K.eval(result_t)
assert np.isclose(result, 2.5 / 2.7, atol=1e-3) # 2.5/2.7

def test_sparse_categorical_accuracy(self):
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')

# check config
assert acc_obj.name == 'my_acc'
assert acc_obj.stateful
assert len(acc_obj.weights) == 2
assert acc_obj.dtype == 'float32'

# verify that correct value is returned
result_t = acc_obj([[2], [1]],
[[0.1, 0.1, 0.8],
[0.05, 0.95, 0]])
result = K.eval(result_t)
assert result == 1 # 2/2

# check with sample_weight
result_t = acc_obj([[2], [1]],
[[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
[[0.5], [0.2]])
result = K.eval(result_t)
assert np.isclose(result, 2.5 / 2.7, atol=1e-3)

def test_sparse_categorical_accuracy_mismatched_dims(self):
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')

# check config
assert acc_obj.name == 'my_acc'
assert acc_obj.stateful
assert len(acc_obj.weights) == 2
assert acc_obj.dtype == 'float32'

# verify that correct value is returned
result_t = acc_obj([2, 1], [[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
result = K.eval(result_t)
assert result == 1 # 2/2

# check with sample_weight
result_t = acc_obj([2, 1], [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
[[0.5], [0.2]])
result = K.eval(result_t)
assert np.isclose(result, 2.5 / 2.7, atol=1e-3)


class TestMeanSquaredErrorTest(object):

Expand Down