Skip to content

Commit 4680d70

Browse files
braingineerfchollet
authored andcommitted
fixing the constants thing in theano rnn (keras-team#2429)
1 parent ee7f056 commit 4680d70

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

keras/backend/theano_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -564,15 +564,15 @@ def rnn(step_function, inputs, initial_states,
564564
axes = [1, 0] + list(range(2, ndim))
565565
inputs = inputs.dimshuffle(axes)
566566

567+
if constants is None:
568+
constants = []
569+
567570
if mask is not None:
568571
if mask.ndim == ndim-1:
569572
mask = expand_dims(mask)
570573
assert mask.ndim == ndim
571574
mask = mask.dimshuffle(axes)
572575

573-
if constants is None:
574-
constants = []
575-
576576
if unroll:
577577
indices = list(range(input_length))
578578
if go_backwards:
@@ -641,7 +641,7 @@ def _step(input, mask, output_tm1, *states):
641641
successive_states = []
642642
states = initial_states
643643
for i in indices:
644-
output, states = step_function(inputs[i], states)
644+
output, states = step_function(inputs[i], states + constants)
645645
successive_outputs.append(output)
646646
successive_states.append(states)
647647
outputs = T.stack(*successive_outputs)

0 commit comments

Comments
 (0)