Skip to content

Commit e341e73

Browse files
committed
Fix Dropout in RNNs
1 parent 5dad378 commit e341e73

File tree

1 file changed: 8 additions and 8 deletions

keras/layers/recurrent.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -549,17 +549,17 @@ def get_constants(self, x):
         if 0 < self.dropout_U < 1:
             ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
             ones = K.concatenate([ones] * self.output_dim, 1)
-            B_U = [K.dropout(ones, self.dropout_U) for _ in range(3)]
+            B_U = [K.in_train_phase(K.dropout(ones, self.dropout_U), ones) for _ in range(3)]
             constants.append(B_U)
         else:
             constants.append([K.cast_to_floatx(1.) for _ in range(4)])

-        if self.consume_less == 'cpu' and 0 < self.dropout_W < 1:
+        if 0 < self.dropout_W < 1:
             input_shape = self.input_spec[0].shape
             input_dim = input_shape[-1]
             ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
             ones = K.concatenate([ones] * input_dim, 1)
-            B_W = [K.dropout(ones, self.dropout_W) for _ in range(3)]
+            B_W = [K.in_train_phase(K.dropout(ones, self.dropout_W), ones) for _ in range(3)]
             constants.append(B_W)
         else:
             constants.append([K.cast_to_floatx(1.) for _ in range(4)])
@@ -715,9 +715,9 @@ def reset_states(self):
         self.states = [K.zeros((input_shape[0], self.output_dim)),
                        K.zeros((input_shape[0], self.output_dim))]

-    def preprocess_input(self, x, train=False):
+    def preprocess_input(self, x):
         if self.consume_less == 'cpu':
-            if train and (0 < self.dropout_W < 1):
+            if 0 < self.dropout_W < 1:
                 dropout = self.dropout_W
             else:
                 dropout = 0
@@ -767,17 +767,17 @@ def get_constants(self, x):
         if 0 < self.dropout_U < 1:
             ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
             ones = K.concatenate([ones] * self.output_dim, 1)
-            B_U = [K.dropout(ones, self.dropout_U) for _ in range(4)]
+            B_U = [K.in_train_phase(K.dropout(ones, self.dropout_U), ones) for _ in range(4)]
             constants.append(B_U)
         else:
            constants.append([K.cast_to_floatx(1.) for _ in range(4)])

-        if self.consume_less == 'cpu' and 0 < self.dropout_W < 1:
+        if 0 < self.dropout_W < 1:
             input_shape = self.input_spec[0].shape
             input_dim = input_shape[-1]
             ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
             ones = K.concatenate([ones] * input_dim, 1)
-            B_W = [K.dropout(ones, self.dropout_W) for _ in range(4)]
+            B_W = [K.in_train_phase(K.dropout(ones, self.dropout_W), ones) for _ in range(4)]
             constants.append(B_W)
         else:
             constants.append([K.cast_to_floatx(1.) for _ in range(4)])

0 commit comments

Comments
 (0)