Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 202 additions & 3 deletions keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class Mean(Reduce):
Usage:

```python
m = tf.keras.metrics.Mean()
m = keras.metrics.Mean()
m.update_state([1, 3, 5, 7])
m.result()
```
Expand Down Expand Up @@ -321,7 +321,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
def get_config(self):
config = {}
for k, v in six.iteritems(self._fn_kwargs):
config[k] = K.eval(v) if is_tensor_or_variable(v) else v
config[k] = K.eval(v) if K.is_tensor(v) else v
base_config = super(MeanMetricWrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

Expand Down Expand Up @@ -350,7 +350,206 @@ def __init__(self, name='mean_squared_error', dtype=None):
mean_squared_error, name, dtype=dtype)


def binary_accuracy(y_true, y_pred):
class Accuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.

For example, if `y_true` is [1, 2, 3, 4] and `y_pred` is [0, 2, 3, 4]
then the accuracy is 3/4 or .75. If the weights were specified as
[1, 1, 0, 0] then the accuracy would be 1/2 or .5.

This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `binary accuracy`: an idempotent operation that simply
divides `total` by `count`.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

```

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile('sgd', loss='mse', metrics=[keras.metrics.Accuracy()])
```
"""

def __init__(self, name='accuracy', dtype=None):
super(Accuracy, self).__init__(accuracy, name, dtype=dtype)


class BinaryAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.

For example, if `y_true` is [1, 1, 0, 0] and `y_pred` is [0.98, 1, 0, 0.6]
then the binary accuracy is 3/4 or .75. If the weights were specified as
[1, 0, 0, 1] then the binary accuracy would be 1/2 or .5.

This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `binary accuracy`: an idempotent operation that simply
divides `total` by `count`.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile('sgd', loss='mse', metrics=[keras.metrics.BinaryAccuracy()])
```
"""

def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
"""Creates a `BinaryAccuracy` instance.

# Arguments
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
threshold: (Optional) Float representing the threshold for deciding
whether prediction values are 1 or 0.
"""
super(BinaryAccuracy, self).__init__(
binary_accuracy, name, dtype=dtype, threshold=threshold)


class CategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.

For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
would be .3. You can provide logits of classes as `y_pred`, since argmax of
logits and probabilities are same.

This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `categorical accuracy`: an idempotent operation that
simply divides `total` by `count`.

`y_pred` and `y_true` should be passed in as vectors of probabilities, rather
than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[keras.metrics.CategoricalAccuracy()])
```
"""

def __init__(self, name='categorical_accuracy', dtype=None):
"""Creates a `CategoricalAccuracy` instance.

# Arguments
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
super(CategoricalAccuracy, self).__init__(
categorical_accuracy, name, dtype=dtype)


class SparseCategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches integer labels.

For example, if `y_true` is [[2], [1]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
would be .3. You can provide logits of classes as `y_pred`, since argmax of
logits and probabilities are same.

This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `sparse categorical accuracy`: an idempotent operation
that simply divides `total` by `count`.

If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
loss='mse',
metrics=[keras.metrics.SparseCategoricalAccuracy()])
```
"""

def __init__(self, name='sparse_categorical_accuracy', dtype=None):
super(SparseCategoricalAccuracy, self).__init__(
sparse_categorical_accuracy, name, dtype=dtype)


class TopKCategoricalAccuracy(MeanMetricWrapper):
"""Computes how often targets are in the top `K` predictions.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile('sgd', metrics=[keras.metrics.TopKCategoricalAccuracy()])
```
"""

def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
"""Creates a `TopKCategoricalAccuracy` instance.

# Arguments
k: (Optional) Number of top elements to look at for computing accuracy.
Defaults to 5.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
super(TopKCategoricalAccuracy, self).__init__(
top_k_categorical_accuracy, name, dtype=dtype, k=k)


class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
"""Computes how often integer targets are in the top `K` predictions.

Usage with the compile API:

```python
model = keras.Model(inputs, outputs)
model.compile(
'sgd',
metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
```
"""

def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
"""Creates a `SparseTopKCategoricalAccuracy` instance.

# Arguments
k: (Optional) Number of top elements to look at for computing accuracy.
Defaults to 5.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
super(SparseTopKCategoricalAccuracy, self).__init__(
sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)


def accuracy(y_true, y_pred):
if not K.is_tensor(y_pred):
y_pred = K.constant(y_pred)
y_true = K.cast(y_true, y_pred.dtype)
return K.cast(K.equal(y_true, y_pred), K.floatx())


def binary_accuracy(y_true, y_pred, threshold=0.5):
threshold = K.cast(threshold, y_pred.dtype)
y_pred = K.cast(y_pred > threshold, y_pred.dtype)
return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)


Expand Down
Loading