
Commit 7e678b8

Objective outputs should rescale based on sample_weights
If sample_weights is to be used as a mask as well as for re-weighting, then it's important that, at least when used as a mask, the output be rescaled. Otherwise the order of magnitude of your objective changes purely based on the number of masked entries in your training data: for example, with a mean per-step loss of 0.69 and half the weights zeroed out, the reported loss would drop to roughly 0.35 without the model getting any better.
1 parent 02d5f72 · commit 7e678b8

File tree

keras/models.py
tests/integration_tests/test_temporal_data_tasks.py

2 files changed: +45 -1 lines changed


keras/models.py

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ def weighted(y_true, y_pred, weights, mask=None):
         # apply sample weighting
         if weights is not None:
             score_array *= weights
+            score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
         return K.mean(score_array)
     return weighted
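Taken on its own, the new line computes the fraction of nonzero weights (K.not_equal(weights, 0) is a 0/1 indicator of unmasked entries) and divides the weighted loss by it, so the mean is effectively taken over unmasked entries only. A minimal NumPy sketch of the effect (a hypothetical illustration, not part of the commit; np.mean stands in for the backend's K.mean):

import numpy as np

# Per-timestep losses, all equal so the scale is easy to read off.
score_array = np.full(8, 0.69)
# Sample weights used as a mask: half the entries are zeroed out.
weights = np.array([1., 1., 1., 1., 0., 0., 0., 0.])

masked = score_array * weights
print(np.mean(masked))    # 0.345 -- halved purely because half the data is masked

# The commit's fix: divide by the mean of the nonzero-weight indicator,
# i.e. the fraction of entries that survive the mask.
masked = masked / np.mean((weights != 0).astype('float32'))
print(np.mean(masked))    # 0.69 -- scale independent of how much is masked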

tests/integration_tests/test_temporal_data_tasks.py

Lines changed: 44 additions & 1 deletion

@@ -5,8 +5,9 @@
 
 from keras.utils.test_utils import get_test_data
 from keras.models import Sequential
-from keras.layers.core import TimeDistributedDense, Dropout, Dense
+from keras.layers.core import TimeDistributedDense, Dropout, Dense, Activation
 from keras.layers.recurrent import GRU, LSTM
+from keras.layers.embeddings import Embedding
 from keras.utils.np_utils import to_categorical
 
 
@@ -126,6 +127,48 @@ def test_stacked_lstm_char_prediction():
     # check that it did generate the alphabet correctly
     assert(generated == alphabet)
 
+def test_masked_temporal():
+    '''
+    Confirm that even with masking on both inputs and outputs, cross-entropies are
+    of the expected scale.
+
+    In this task, there are variable-length inputs of integers from 1-9, and a random
+    subset of unmasked outputs. Each of these outputs has a 50% probability of being
+    the input number unchanged, and a 50% probability of being 2*input % 10.
+
+    The ground-truth best cross-entropy loss should, then, be -log(0.5) = 0.69.
+
+    '''
+    np.random.seed(55318)
+    model = Sequential()
+    model.add(Embedding(10, 20, mask_zero=True))
+    model.add(TimeDistributedDense(10))
+    model.add(Activation('softmax'))
+    model.compile(loss='categorical_crossentropy',
+                  optimizer='adam', sample_weight_mode="temporal")
+
+    X = np.random.random_integers(1, 9, (50000, 20))
+    for rowi in range(X.shape[0]):
+        padding = np.random.random_integers(X.shape[1] / 2)
+        X[rowi, :padding] = 0
+
+    # 50% of the time the correct output is the input. The other 50% of the time
+    # it's 2*input % 10.
+    y = (X * np.random.random_integers(1, 2, X.shape)) % 10
+    Y = np.zeros((y.size, 10), dtype='int32')
+    for i, target in enumerate(y.flat):
+        Y[i, target] = 1
+    Y = Y.reshape(y.shape + (10,))
+
+    # Mask 50% of the outputs via sample weights
+    sample_weight = np.random.random_integers(0, 1, y.shape)
+    print("X shape: ", X.shape)
+    print("Y shape: ", Y.shape)
+    print("sample_weight shape: ", sample_weight.shape)
+
+    history = model.fit(X, Y, validation_split=0.05, sample_weight=sample_weight, verbose=1, nb_epoch=2)
+    ground_truth = -np.log(0.5)
+    assert(np.abs(history.history['val_loss'][-1] - ground_truth) < 0.05)
 
 if __name__ == '__main__':
     pytest.main([__file__])
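The docstring's target value checks out: each unmasked timestep has two equally likely correct classes, so the best achievable softmax puts probability 0.5 on each, and the expected categorical cross-entropy is -log(0.5). A quick side calculation (hypothetical, not part of the commit):

import numpy as np

# Two equally likely targets per unmasked timestep; the optimal
# prediction assigns probability 0.5 to each, so the expected
# cross-entropy is H = -(0.5*log(0.5) + 0.5*log(0.5)) = log(2).
p = np.array([0.5, 0.5])
print(-np.sum(p * np.log(p)))    # 0.6931... == -np.log(0.5)

This is why the test asserts val_loss within 0.05 of 0.693: without the rescaling in keras/models.py, masking roughly half the outputs would have pulled the reported loss toward half that value.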
