Skip to content

Commit 0792332

Browse files
pavithrasv and fchollet
authored and committed
Adding Loss, LossFunctionWrapper, MeanSquaredError classes. (#12859)
* Adding Loss, LossFunctionWrapper, MeanSquaredError classes.
* Fixing formatting issues.
* Adding arguments list to MeanSquaredError.
* Fix abstract method
1 parent 5430a7b commit 0792332

File tree

8 files changed

+499
-88
lines changed

8 files changed

+499
-88
lines changed

keras/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
from .load_backend import name_scope
150150
from .load_backend import symbolic
151151
from .load_backend import eager
152+
from .load_backend import size
152153

153154
if backend() == 'theano':
154155
from .load_backend import pattern_broadcast

keras/backend/cntk_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,10 @@ def cast(x, dtype):
559559
return x
560560

561561

562+
def size(x, name=None):
    """Returns the size of a tensor.

    # Arguments
        x: Tensor or variable.
        name: A name forwarded to the `ones_like` op (optional).

    # Returns
        The result of summing a tensor of ones shaped like `x`.
    """
    # NOTE(review): `sum` and `ones_like` are presumably this backend
    # module's own ops rather than the Python builtin -- confirm.
    unit_tensor = ones_like(x, name=name)
    return sum(unit_tensor)
564+
565+
562566
def dot(x, y):
563567
if len(x.shape) > 2 or len(y.shape) > 2:
564568
y_shape = int_shape(y)

keras/backend/tensorflow_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,29 @@ def ndim(x):
831831
return x.shape.rank
832832

833833

834+
def size(x, name=None):
    """Returns the size of a tensor.

    # Arguments
        x: Tensor or variable.
        name: A name for the operation (optional).

    # Returns
        Size of the tensor.

    # Examples
    ```python
        >>> import numpy as np
        >>> from keras import backend as K
        >>> val = np.array([[1, 2], [3, 4]])
        >>> kvar = K.variable(value=val)
        >>> K.size(kvar)
        <tf.Tensor: id=9, shape=(), dtype=int32, numpy=4>
    ```

    """
    return tf.size(x, name=name)
855+
856+
834857
def dtype(x):
835858
"""Returns the dtype of a Keras tensor or variable, as a string.
836859

keras/backend/theano_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,18 @@ def count_params(x):
374374

375375
def cast(x, dtype):
    """Casts `x` to `dtype`.

    # Arguments
        x: Value to cast.
        dtype: Target dtype, forwarded to `T.cast` unchanged.

    # Returns
        The result of `T.cast(x, dtype)`.
    """
    casted = T.cast(x, dtype)
    return casted
377+
378+
379+
def size(x, name=None):
    """Returns the size of a tensor.

    # Arguments
        x: The input tensor.
        name: A name for the operation (optional).

    # Returns
        Size of the tensor.
    """
    # NOTE(review): `sum` and `ones_like` are presumably this backend
    # module's own ops rather than the Python builtin -- confirm.
    return sum(ones_like(x, name=name))
377389

378390

379391
# UPDATES OPS

keras/losses.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,174 @@
44
from __future__ import division
55
from __future__ import print_function
66

7+
import abc
78
import six
9+
810
from . import backend as K
11+
from .utils import losses_utils
912
from .utils.generic_utils import deserialize_keras_object
1013
from .utils.generic_utils import serialize_keras_object
1114

1215

16+
@six.add_metaclass(abc.ABCMeta)
class Loss(object):
    """Loss base class.

    To be implemented by subclasses:
        * `call()`: Contains the logic for loss calculation using
            `y_true`, `y_pred`.

    Example subclass implementation:
    ```python
    class MeanSquaredError(Loss):
        def call(self, y_true, y_pred):
            return K.mean(K.square(y_pred - y_true), axis=-1)
    ```

    # Arguments
        reduction: (Optional) Type of loss Reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: Optional name for the op.
    """

    def __init__(self,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name=None):
        self.reduction = reduction
        self.name = name

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Invokes the `Loss` instance.

        Runs `call()` inside a name scope and hands the result to
        `losses_utils.compute_weighted_loss` together with
        `sample_weight` and this instance's `reduction`.

        # Arguments
            y_true: Ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional weighting applied to the values
                returned by `call()`; forwarded unchanged to
                `losses_utils.compute_weighted_loss`.

        # Returns
            The value returned by `losses_utils.compute_weighted_loss`.

        # Raises
            ValueError: If raised by
                `losses_utils.compute_weighted_loss`.
        """
        # A wrapped lambda is named '<lambda>'; the angle brackets are
        # not accepted in a scope name, so strip them.
        scope_name = self.name
        if scope_name == '<lambda>':
            scope_name = 'lambda'
        with K.name_scope(scope_name):
            per_sample_losses = self.call(y_true, y_pred)
            return losses_utils.compute_weighted_loss(
                per_sample_losses, sample_weight, reduction=self.reduction)

    @classmethod
    def from_config(cls, config):
        """Instantiates a `Loss` from its config (output of `get_config()`).

        # Arguments
            config: Output of `get_config()`; expanded as keyword
                arguments to the constructor.

        # Returns
            A `Loss` instance.
        """
        return cls(**config)

    def get_config(self):
        """Returns a dict holding this loss's `reduction` and `name`."""
        config = {}
        config['reduction'] = self.reduction
        config['name'] = self.name
        return config

    @abc.abstractmethod
    def call(self, y_true, y_pred):
        """Computes the loss values; must be overridden by subclasses.

        # Arguments
            y_true: Ground truth values, with the same shape as 'y_pred'.
            y_pred: The predicted values.

        # Raises
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError('Must be implemented in subclasses.')
99+
100+
101+
class LossFunctionWrapper(Loss):
    """Wraps a loss function in the `Loss` class.

    # Arguments
        fn: The loss function to wrap, with signature `fn(y_true, y_pred,
            **kwargs)`.
        reduction: (Optional) Type of loss reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: (Optional) name for the loss.
        **kwargs: The keyword arguments that are passed on to `fn`.
    """

    def __init__(self,
                 fn,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name=None,
                 **kwargs):
        super(LossFunctionWrapper, self).__init__(
            reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        """Invokes the `LossFunctionWrapper` instance.

        # Arguments
            y_true: Ground truth values.
            y_pred: The predicted values.

        # Returns
            The value returned by the wrapped `fn` when called with
            `y_true`, `y_pred` and the stored keyword arguments.
        """
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        """Returns the base config extended with the stored `fn` kwargs.

        Tensor-valued kwargs are evaluated with `K.eval`; all other
        values are stored as-is. A kwarg sharing a key with the base
        config takes precedence over the base entry.
        """
        # Copy so the dict returned by the parent is never mutated.
        config = dict(super(LossFunctionWrapper, self).get_config())
        for key, value in six.iteritems(self._fn_kwargs):
            # `is_tensor_or_variable` is not defined or imported in this
            # module; use the backend's own tensor check instead.
            # NOTE(review): assumes `K.is_tensor` also recognises backend
            # variables -- confirm for each backend.
            config[key] = K.eval(value) if K.is_tensor(value) else value
        return config
140+
141+
142+
class MeanSquaredError(LossFunctionWrapper):
    """Computes the mean of squares of errors between labels and predictions.

    For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is
    [1., 1., 1., 0.] then the mean squared error value is 3/4 (0.75).

    Standalone usage:

    ```python
    mse = keras.losses.MeanSquaredError()
    loss = mse([0., 0., 1., 1.], [1., 1., 1., 0.])
    ```

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', loss=keras.losses.MeanSquaredError())
    ```

    # Arguments
        reduction: (Optional) Type of loss reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: (Optional) name for the loss.
    """

    def __init__(self,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name='mean_squared_error'):
        # `mean_squared_error` is the module-level function; the name is
        # looked up when the constructor runs.
        super(MeanSquaredError, self).__init__(
            mean_squared_error,
            reduction=reduction,
            name=name)
173+
174+
13175
def mean_squared_error(y_true, y_pred):
    """Mean of the squared difference between `y_pred` and `y_true`.

    # Arguments
        y_true: Ground truth values.
        y_pred: The predicted values.

    # Returns
        `K.mean` of `K.square(y_pred - y_true)` taken over the last axis.
    """
    squared_error = K.square(y_pred - y_true)
    return K.mean(squared_error, axis=-1)
15177

keras/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from . import data_utils
55
from . import io_utils
66
from . import conv_utils
7+
from . import losses_utils
78

89
# Globally-importable utils.
910
from .io_utils import HDF5Matrix

keras/utils/losses_utils.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Utilities related to losses."""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import numpy as np
7+
8+
from .. import backend as K
9+
10+
11+
class Reduction(object):
    """Types of loss reduction.

    Contains the following values:

    * `NONE`: Un-reduced weighted losses with the same shape as input. When
        this reduction type used with built-in Keras training loops like
        `fit`/`evaluate`, the unreduced vector loss is passed to the
        optimizer but the reported loss will be a scalar value.
    * `SUM`: Scalar sum of weighted losses.
    * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in
        losses.
    """

    NONE = 'none'
    SUM = 'sum'
    SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

    @classmethod
    def all(cls):
        """Returns a tuple of every valid reduction key."""
        return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

    @classmethod
    def validate(cls, key):
        """Checks that `key` is a valid reduction key.

        # Arguments
            key: The value to check against `all()`.

        # Raises
            ValueError: If `key` is not one of the values in `all()`.
        """
        if key not in cls.all():
            # Wrap `key` in a 1-tuple: a bare tuple-valued key would be
            # unpacked as `%` arguments and raise TypeError instead of
            # the intended ValueError.
            raise ValueError('Invalid Reduction Key %s.' % (key,))
36+
37+
38+
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
    """Squeeze or expand last dimension if needed.

    1. When `y_true` is given and the ranks of `y_pred` and `y_true`
       differ by exactly 1, squeezes the last dim of the higher-rank one
       if that dim has size 1.
    2. When `sample_weight` is given with non-zero rank, squeezes or
       expands its last dim if its rank differs by exactly 1 from the
       (possibly updated) rank of `y_pred`.
       A `sample_weight` of rank 0 is returned untouched.

    # Arguments
        y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
        y_true: Optional label `Tensor` whose dimensions match `y_pred`.
        sample_weight: Optional weight scalar or `Tensor` whose dimensions
            match `y_pred`.

    # Returns
        Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them
        possibly has the last dimension squeezed, `sample_weight` could be
        extended by one dimension.
    """
    if y_true is not None:
        pred_rank = K.ndim(y_pred)
        pred_shape = K.int_shape(y_pred)
        true_rank = K.ndim(y_true)
        true_shape = K.int_shape(y_true)

        # Positive gap: `y_pred` has the extra axis; negative: `y_true`.
        rank_gap = pred_rank - true_rank
        if rank_gap == 1 and pred_shape[-1] == 1:
            y_pred = K.squeeze(y_pred, -1)
        elif rank_gap == -1 and true_shape[-1] == 1:
            y_true = K.squeeze(y_true, -1)

    if sample_weight is None:
        return y_pred, y_true, None

    # Re-read the rank: `y_pred` may have been squeezed above.
    pred_rank = K.ndim(y_pred)
    weights_rank = K.ndim(sample_weight)
    if weights_rank == 0:
        return y_pred, y_true, sample_weight

    weight_gap = weights_rank - pred_rank
    if weight_gap == 1:
        sample_weight = K.squeeze(sample_weight, -1)
    elif weight_gap == -1:
        sample_weight = K.expand_dims(sample_weight, -1)
    return y_pred, y_true, sample_weight
79+
80+
81+
def _num_elements(losses):
    """Computes the number of elements in `losses` tensor.

    # Arguments
        losses: Tensor whose elements are counted via `K.size`.

    # Returns
        The element count, cast to the dtype of `losses`.
    """
    with K.name_scope('num_elements') as scope:
        element_count = K.size(losses, name=scope)
        return K.cast(element_count, losses.dtype)
85+
86+
87+
def reduce_weighted_loss(weighted_losses,
                         reduction=Reduction.SUM_OVER_BATCH_SIZE):
    """Reduces the individual weighted loss measurements.

    # Arguments
        weighted_losses: Tensor of already-weighted loss values.
        reduction: One of the `Reduction` keys.
            Default value is `SUM_OVER_BATCH_SIZE`.

    # Returns
        `weighted_losses` unchanged for `NONE`; their `K.sum` for any
        other key, additionally divided by the number of elements for
        `SUM_OVER_BATCH_SIZE`.
    """
    if reduction == Reduction.NONE:
        return weighted_losses

    total = K.sum(weighted_losses)
    if reduction != Reduction.SUM_OVER_BATCH_SIZE:
        return total
    return total / _num_elements(weighted_losses)
96+
97+
98+
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=Reduction.SUM_OVER_BATCH_SIZE,
                          name=None):
    """Computes the weighted loss.

    # Arguments
        losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
        sample_weight: Optional `Tensor` whose rank is either 0, or the
            same rank as `losses`, or is broadcastable to `losses`.
            When `None`, a weight of `1.0` is used.
        reduction: (Optional) Type of Reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: Optional name for the op; the scope defaults to
            `'weighted_loss'`.

    # Raises
        ValueError: If `reduction` is not a valid `Reduction` key.

    # Returns
        Weighted loss `Tensor` of the same type as `losses`. If
        `reduction` is `NONE`, the weighted losses are returned
        unreduced; otherwise, the result is scalar.
    """
    Reduction.validate(reduction)
    weights = 1.0 if sample_weight is None else sample_weight

    with K.name_scope(name or 'weighted_loss'):
        # Remember the incoming dtype so the result can be cast back.
        original_dtype = K.dtype(losses)
        float_losses = K.cast(losses, K.floatx())
        weights = K.cast(weights, K.floatx())

        # Align the rank of the weights with the losses where possible.
        float_losses, _, weights = squeeze_or_expand_dimensions(
            float_losses, None, weights)

        reduced = reduce_weighted_loss(float_losses * weights, reduction)
        result = K.cast(reduced, original_dtype)
    return result

0 commit comments

Comments
 (0)