|
4 | 4 | from __future__ import division |
5 | 5 | from __future__ import print_function |
6 | 6 |
|
| 7 | +import abc |
7 | 8 | import six |
| 9 | + |
8 | 10 | from . import backend as K |
| 11 | +from .utils import losses_utils |
9 | 12 | from .utils.generic_utils import deserialize_keras_object |
10 | 13 | from .utils.generic_utils import serialize_keras_object |
11 | 14 |
|
12 | 15 |
|
@six.add_metaclass(abc.ABCMeta)
class Loss(object):
    """Abstract base class for all Keras loss objects.

    Subclasses provide `call()`, which contains the per-sample loss logic
    computed from `y_true` and `y_pred`.

    Example subclass implementation:
    ```python
    class MeanSquaredError(Loss):
        def call(self, y_true, y_pred):
            y_pred = ops.convert_to_tensor(y_pred)
            y_true = math_ops.cast(y_true, y_pred.dtype)
            return K.mean(math_ops.square(y_pred - y_true), axis=-1)
    ```

    # Arguments
        reduction: (Optional) Type of loss Reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: Optional name for the op.
    """

    def __init__(self,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name=None):
        self.reduction = reduction
        self.name = name

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Computes the weighted, reduced loss for the given inputs.

        # Arguments
            y_true: Ground truth values.
            y_pred: The predicted values.
            sample_weight: Optional `Tensor` whose rank is either 0, or the same
                rank as `y_true`, or is broadcastable to `y_true`.
                `sample_weight` acts as a coefficient for the loss. If a scalar
                is provided, then the loss is simply scaled by the given value.
                If `sample_weight` is a tensor of size `[batch_size]`, then the
                total loss for each sample of the batch is rescaled by the
                corresponding element in the `sample_weight` vector. If the
                shape of `sample_weight` matches the shape of `y_pred`, then the
                loss of each measurable element of `y_pred` is scaled by the
                corresponding value of `sample_weight`.

        # Returns
            Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the
            same shape as `y_true`; otherwise, it is scalar.

        # Raises
            ValueError: If the shape of `sample_weight` is invalid.
        """
        # A wrapped lambda reports '<lambda>' as its name; '<' and '>' are not
        # accepted in a scope name, so substitute a sanitized label.
        scope = self.name
        if scope == '<lambda>':
            scope = 'lambda'
        with K.name_scope(scope):
            per_sample_losses = self.call(y_true, y_pred)
            return losses_utils.compute_weighted_loss(
                per_sample_losses, sample_weight, reduction=self.reduction)

    @classmethod
    def from_config(cls, config):
        """Instantiates a `Loss` from its config (output of `get_config()`).

        # Arguments
            config: Output of `get_config()`.

        # Returns
            A `Loss` instance.
        """
        return cls(**config)

    def get_config(self):
        """Returns the serializable configuration of this loss."""
        return {'reduction': self.reduction, 'name': self.name}

    @abc.abstractmethod
    def call(self, y_true, y_pred):
        """Invokes the `Loss` instance.

        # Arguments
            y_true: Ground truth values, with the same shape as 'y_pred'.
            y_pred: The predicted values.
        """
        raise NotImplementedError('Must be implemented in subclasses.')
| 99 | + |
| 100 | + |
class LossFunctionWrapper(Loss):
    """Adapts a plain loss function to the `Loss` interface.

    # Arguments
        fn: The loss function to wrap, with signature `fn(y_true, y_pred,
            **kwargs)`.
        reduction: (Optional) Type of loss reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: (Optional) name for the loss.
        **kwargs: The keyword arguments that are passed on to `fn`.
    """

    def __init__(self,
                 fn,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name=None,
                 **kwargs):
        super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        """Delegates the per-sample loss computation to the wrapped `fn`.

        # Arguments
            y_true: Ground truth values.
            y_pred: The predicted values.

        # Returns
            Loss values per sample.
        """
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        """Returns the config, evaluating tensor-valued kwargs to plain values."""
        # NOTE(review): `is_tensor_or_variable` is not among this file's visible
        # imports -- confirm it is actually in scope at module level.
        fn_config = {
            k: (K.eval(v) if is_tensor_or_variable(v) else v)
            for k, v in six.iteritems(self._fn_kwargs)
        }
        merged = super(LossFunctionWrapper, self).get_config()
        merged.update(fn_config)
        return merged
| 140 | + |
| 141 | + |
class MeanSquaredError(LossFunctionWrapper):
    """Computes the mean of squares of errors between labels and predictions.

    For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
    then the mean squared error value is 3/4 (0.75).

    Standalone usage:

    ```python
    mse = keras.losses.MeanSquaredError()
    loss = mse([0., 0., 1., 1.], [1., 1., 1., 0.])
    ```

    Usage with the `compile` API:

    ```python
    model = keras.Model(inputs, outputs)
    model.compile('sgd', loss=keras.losses.MeanSquaredError())
    ```

    # Arguments
        reduction: (Optional) Type of loss reduction to apply to loss.
            Default value is `SUM_OVER_BATCH_SIZE`.
        name: (Optional) name for the loss.
    """

    def __init__(self,
                 reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE,
                 name='mean_squared_error'):
        # All of the work is done by the wrapped `mean_squared_error` function.
        super(MeanSquaredError, self).__init__(
            mean_squared_error, reduction=reduction, name=name)
| 173 | + |
| 174 | + |
def mean_squared_error(y_true, y_pred):
    """Returns the mean squared error between `y_true` and `y_pred`,
    averaged over the last axis."""
    squared_difference = K.square(y_pred - y_true)
    return K.mean(squared_difference, axis=-1)
15 | 177 |
|
|
0 commit comments