Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add TopK accuracy metric classes.
  • Loading branch information
pavithrasv committed Aug 30, 2019
commit 3cf3d645de8d8ce717d9f06bb18b58cde3348ade
68 changes: 59 additions & 9 deletions keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
def get_config(self):
    """Returns the serializable config of this metric wrapper.

    Any tensor-valued keyword argument is evaluated to a concrete Python
    value (via `K.eval`) so that the resulting config can be serialized.

    # Returns
        A dict containing the wrapped function's kwargs merged with the
        base metric config.
    """
    config = {}
    for k, v in six.iteritems(self._fn_kwargs):
        # Tensors cannot be serialized directly; evaluate them first.
        config[k] = K.eval(v) if K.is_tensor(v) else v
    base_config = super(MeanMetricWrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

Expand Down Expand Up @@ -405,11 +405,11 @@ class BinaryAccuracy(MeanMetricWrapper):
def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
    """Creates a `BinaryAccuracy` instance.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
        threshold: (Optional) Float representing the threshold for deciding
            whether prediction values are 1 or 0.
    """
    super(BinaryAccuracy, self).__init__(
        binary_accuracy, name, dtype=dtype, threshold=threshold)
Expand Down Expand Up @@ -449,9 +449,9 @@ class CategoricalAccuracy(MeanMetricWrapper):
def __init__(self, name='categorical_accuracy', dtype=None):
    """Creates a `CategoricalAccuracy` instance.

    # Arguments
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.
    """
    super(CategoricalAccuracy, self).__init__(
        categorical_accuracy, name, dtype=dtype)
Expand Down Expand Up @@ -490,6 +490,56 @@ def __init__(self, name='sparse_categorical_accuracy', dtype=None):
sparse_categorical_accuracy, name, dtype=dtype)


class TopKCategoricalAccuracy(MeanMetricWrapper):
    """Calculates how often the true class is among the top `K` predictions.

    Usage with the compile API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', metrics=[keras.metrics.TopKCategoricalAccuracy()])
    ```
    """

    def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
        """Creates a `TopKCategoricalAccuracy` instance.

        # Arguments
            k: (Optional) Number of top elements to look at for computing
                accuracy. Defaults to 5.
            name: (Optional) string name of the metric instance.
            dtype: (Optional) data type of the metric result.
        """
        # Delegate the per-batch computation to the functional metric;
        # `k` is forwarded to it on every update.
        metric_fn = top_k_categorical_accuracy
        super(TopKCategoricalAccuracy, self).__init__(
            metric_fn, name, dtype=dtype, k=k)


class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
    """Calculates how often integer labels fall in the top `K` predictions.

    Usage with the compile API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile(
        'sgd',
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
    ```
    """

    def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
        """Creates a `SparseTopKCategoricalAccuracy` instance.

        # Arguments
            k: (Optional) Number of top elements to look at for computing
                accuracy. Defaults to 5.
            name: (Optional) string name of the metric instance.
            dtype: (Optional) data type of the metric result.
        """
        # Delegate the per-batch computation to the functional metric;
        # `k` is forwarded to it on every update.
        metric_fn = sparse_top_k_categorical_accuracy
        super(SparseTopKCategoricalAccuracy, self).__init__(
            metric_fn, name, dtype=dtype, k=k)


def accuracy(y_true, y_pred):
if not K.is_tensor(y_pred):
y_pred = K.constant(y_pred)
Expand Down
83 changes: 83 additions & 0 deletions tests/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,86 @@ def test_weighted(self):
sample_weight = (1., 1.5, 2., 2.5)
result = mse_obj(y_true, y_pred, sample_weight=sample_weight)
np.isclose(0.54285, K.eval(result), atol=1e-5)


class TestTopKCategoricalAccuracy(object):

    def test_config(self):
        metric = metrics.TopKCategoricalAccuracy(name='topkca', dtype='int32')
        assert metric.name == 'topkca'
        assert metric.dtype == 'int32'

        # Round-trip through get_config/from_config must preserve settings.
        restored = metrics.TopKCategoricalAccuracy.from_config(
            metric.get_config())
        assert restored.name == 'topkca'
        assert restored.dtype == 'int32'

    def test_correctness(self):
        y_true = [[0, 0, 1], [0, 1, 0]]
        y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]

        # Default k=5: both samples' labels are within the top 5.
        metric = metrics.TopKCategoricalAccuracy()
        assert 1 == K.eval(metric(y_true, y_pred))

        # k=1: only sample #2's label is the argmax.
        metric = metrics.TopKCategoricalAccuracy(k=1)
        assert 0.5 == K.eval(metric(y_true, y_pred))

        # k > 5 with seven classes: only one sample matches.
        y_true = [[0, 0, 1, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0]]
        y_pred = [[0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4],
                  [0.05, 0.95, 0, 0, 0, 0, 0]]
        metric = metrics.TopKCategoricalAccuracy(k=6)
        assert 0.5 == K.eval(metric(y_true, y_pred))

    def test_weighted(self):
        metric = metrics.TopKCategoricalAccuracy(k=2)
        y_true = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
        y_pred = [[0, 0.9, 0.1], [0, 0.9, 0.1], [0, 0.9, 0.1]]
        weights = (1.0, 0.0, 1.0)
        # The non-matching sample (#2) carries zero weight, so accuracy is 1.
        result = metric(y_true, y_pred, sample_weight=weights)
        assert np.allclose(1.0, K.eval(result), atol=1e-5)


class TestSparseTopKCategoricalAccuracy(object):

    def test_config(self):
        metric = metrics.SparseTopKCategoricalAccuracy(
            name='stopkca', dtype='int32')
        assert metric.name == 'stopkca'
        assert metric.dtype == 'int32'

        # Round-trip through get_config/from_config must preserve settings.
        restored = metrics.SparseTopKCategoricalAccuracy.from_config(
            metric.get_config())
        assert restored.name == 'stopkca'
        assert restored.dtype == 'int32'

    def test_correctness(self):
        y_true = [2, 1]
        y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]

        # Default k=5: both samples' labels are within the top 5.
        metric = metrics.SparseTopKCategoricalAccuracy()
        assert 1 == K.eval(metric(y_true, y_pred))

        # k=1: only sample #2's label is the argmax.
        metric = metrics.SparseTopKCategoricalAccuracy(k=1)
        assert 0.5 == K.eval(metric(y_true, y_pred))

        # k > 5 with seven classes: only one sample matches.
        y_pred = [[0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4],
                  [0.05, 0.95, 0, 0, 0, 0, 0]]
        metric = metrics.SparseTopKCategoricalAccuracy(k=6)
        assert 0.5 == K.eval(metric(y_true, y_pred))

    def test_weighted(self):
        metric = metrics.SparseTopKCategoricalAccuracy(k=2)
        y_true = [1, 0, 2]
        y_pred = [[0, 0.9, 0.1], [0, 0.9, 0.1], [0, 0.9, 0.1]]
        weights = (1.0, 0.0, 1.0)
        # The non-matching sample (#2) carries zero weight, so accuracy is 1.
        result = metric(y_true, y_pred, sample_weight=weights)
        assert np.allclose(1.0, K.eval(result), atol=1e-5)