@@ -549,17 +549,17 @@ def get_constants(self, x):
549549 if 0 < self .dropout_U < 1 :
550550 ones = K .ones_like (K .reshape (x [:, 0 , 0 ], (- 1 , 1 )))
551551 ones = K .concatenate ([ones ] * self .output_dim , 1 )
552- B_U = [K .dropout (ones , self .dropout_U ) for _ in range (3 )]
552+ B_U = [K .in_train_phase ( K . dropout (ones , self .dropout_U ), ones ) for _ in range (3 )]
553553 constants .append (B_U )
554554 else :
555555 constants .append ([K .cast_to_floatx (1. ) for _ in range (4 )])
556556
557- if self . consume_less == 'cpu' and 0 < self .dropout_W < 1 :
557+ if 0 < self .dropout_W < 1 :
558558 input_shape = self .input_spec [0 ].shape
559559 input_dim = input_shape [- 1 ]
560560 ones = K .ones_like (K .reshape (x [:, 0 , 0 ], (- 1 , 1 )))
561561 ones = K .concatenate ([ones ] * input_dim , 1 )
562- B_W = [K .dropout (ones , self .dropout_W ) for _ in range (3 )]
562+ B_W = [K .in_train_phase ( K . dropout (ones , self .dropout_W ), ones ) for _ in range (3 )]
563563 constants .append (B_W )
564564 else :
565565 constants .append ([K .cast_to_floatx (1. ) for _ in range (4 )])
@@ -715,9 +715,9 @@ def reset_states(self):
715715 self .states = [K .zeros ((input_shape [0 ], self .output_dim )),
716716 K .zeros ((input_shape [0 ], self .output_dim ))]
717717
718- def preprocess_input (self , x , train = False ):
718+ def preprocess_input (self , x ):
719719 if self .consume_less == 'cpu' :
720- if train and ( 0 < self .dropout_W < 1 ) :
720+ if 0 < self .dropout_W < 1 :
721721 dropout = self .dropout_W
722722 else :
723723 dropout = 0
@@ -767,17 +767,17 @@ def get_constants(self, x):
767767 if 0 < self .dropout_U < 1 :
768768 ones = K .ones_like (K .reshape (x [:, 0 , 0 ], (- 1 , 1 )))
769769 ones = K .concatenate ([ones ] * self .output_dim , 1 )
770- B_U = [K .dropout (ones , self .dropout_U ) for _ in range (4 )]
770+ B_U = [K .in_train_phase ( K . dropout (ones , self .dropout_U ), ones ) for _ in range (4 )]
771771 constants .append (B_U )
772772 else :
773773 constants .append ([K .cast_to_floatx (1. ) for _ in range (4 )])
774774
775- if self . consume_less == 'cpu' and 0 < self .dropout_W < 1 :
775+ if 0 < self .dropout_W < 1 :
776776 input_shape = self .input_spec [0 ].shape
777777 input_dim = input_shape [- 1 ]
778778 ones = K .ones_like (K .reshape (x [:, 0 , 0 ], (- 1 , 1 )))
779779 ones = K .concatenate ([ones ] * input_dim , 1 )
780- B_W = [K .dropout (ones , self .dropout_W ) for _ in range (4 )]
780+ B_W = [K .in_train_phase ( K . dropout (ones , self .dropout_W ), ones ) for _ in range (4 )]
781781 constants .append (B_W )
782782 else :
783783 constants .append ([K .cast_to_floatx (1. ) for _ in range (4 )])
0 commit comments