@@ -154,12 +154,12 @@ def mean(x, axis=None, keepdims=False):
154154def any (x , axis = None , keepdims = False ):
155155 '''Bitwise reduction (logical OR).
156156
157- Return array of int8 (0s and 1s).
157+ Return array of uint8 (0s and 1s).
158158 '''
159159 axis = normalize_axis (axis , ndim (x ))
160160 x = tf .cast (x , tf .bool )
161161 x = tf .reduce_any (x , reduction_indices = axis , keep_dims = keepdims )
162- return tf .cast (x , tf .int8 )
162+ return tf .cast (x , tf .uint8 )
163163
164164
165165def argmax (x , axis = - 1 ):
@@ -289,6 +289,7 @@ def repeat(x, n):
289289 if x has shape (samples, dim) and n=2,
290290 the output will have shape (samples, 2, dim)
291291 '''
292+ assert ndim (x ) == 2
292293 tensors = [x ] * n
293294 stacked = tf .pack (tensors )
294295 return tf .transpose (stacked , (1 , 0 , 2 ))
@@ -429,54 +430,53 @@ def rnn(step_function, inputs, initial_states,
429430 axes = [1 , 0 ] + list (range (2 , ndim ))
430431 inputs = tf .transpose (inputs , (axes ))
431432 input_list = tf .unpack (inputs )
432- if mask is None :
433- mask = ones_like (tf .slice (inputs , [0 , 0 , 0 ], [- 1 , - 1 , 1 ]))
434- inputs_shape = inputs .get_shape ()
435-
436- # TODO: the mask's shape should be automatically inferred, by
437- # tensorflow yet for some reason it fails to in some test-cases. This
438- # fixes the issue, but should be removed in future.
439- mask .set_shape ([inputs_shape [0 ].value , inputs_shape [1 ].value , 1 ])
440- mask = tf .cast (mask , tf .bool )
441- else :
442- # Transpose not supported by bool tensor types, hence round-trip to uint8.
443- mask = tf .cast (tf .transpose (tf .cast (mask , tf .uint8 ), axes ), tf .bool )
444-
445- mask_list = tf .unpack (mask )
446433
447434 states = initial_states
448435 successive_states = []
449436 successive_outputs = []
450437 if go_backwards :
451438 input_list .reverse ()
452439
453- for input , mask_t in zip (input_list , mask_list ):
454- output , new_states = step_function (input , states )
455-
456- # tf.select needs its condition tensor to be the same shape as its two
457- # result tensors, but in our case the condition (mask) tensor is
458- # (nsamples, 1), and A and B are (nsamples, ndimensions). So we need to
459- # broadcast the mask to match the shape of A and B. That's what the
460- # tile call does, is just repeat the mask along its second dimension
461- # ndimensions times.
462- tiled_mask_t = tf .tile (mask_t , tf .pack ([1 , tf .shape (output )[1 ]]))
463-
464- if len (successive_outputs ) == 0 :
465- prev_output = zeros_like (output )
466- else :
467- prev_output = successive_outputs [- 1 ]
468-
469- output = tf .select (tiled_mask_t , output , prev_output )
470-
471- return_states = []
472- for state , new_state in zip (states , new_states ):
473- # (see earlier comment for tile explanation)
474- tiled_mask_t = tf .tile (mask_t , tf .pack ([1 , tf .shape (new_state )[1 ]]))
475- return_states .append (tf .select (tiled_mask_t , new_state , state ))
476-
477- states = return_states
478- successive_outputs .append (output )
479- successive_states .append (states )
440+ if mask is not None :
441+ # Transpose not supported by bool tensor types, hence round-trip to uint8.
442+ mask = tf .cast (mask , tf .uint8 )
443+ if len (mask .get_shape ()) == ndim - 1 :
444+ mask = expand_dims (mask )
445+ mask = tf .cast (tf .transpose (mask , axes ), tf .bool )
446+ mask_list = tf .unpack (mask )
447+
448+ for input , mask_t in zip (input_list , mask_list ):
449+ output , new_states = step_function (input , states )
450+
451+ # tf.select needs its condition tensor to be the same shape as its two
452+ # result tensors, but in our case the condition (mask) tensor is
453+ # (nsamples, 1), and A and B are (nsamples, ndimensions). So we need to
454+ # broadcast the mask to match the shape of A and B. That's what the
455+ # tile call does, is just repeat the mask along its second dimension
456+ # ndimensions times.
457+ tiled_mask_t = tf .tile (mask_t , tf .pack ([1 , tf .shape (output )[1 ]]))
458+
459+ if len (successive_outputs ) == 0 :
460+ prev_output = zeros_like (output )
461+ else :
462+ prev_output = successive_outputs [- 1 ]
463+
464+ output = tf .select (tiled_mask_t , output , prev_output )
465+
466+ return_states = []
467+ for state , new_state in zip (states , new_states ):
468+ # (see earlier comment for tile explanation)
469+ tiled_mask_t = tf .tile (mask_t , tf .pack ([1 , tf .shape (new_state )[1 ]]))
470+ return_states .append (tf .select (tiled_mask_t , new_state , state ))
471+
472+ states = return_states
473+ successive_outputs .append (output )
474+ successive_states .append (states )
475+ else :
476+ for input in input_list :
477+ output , states = step_function (input , states )
478+ successive_outputs .append (output )
479+ successive_states .append (states )
480480
481481 last_output = successive_outputs [- 1 ]
482482 outputs = tf .pack (successive_outputs )
0 commit comments