@@ -1440,7 +1440,7 @@ def PackSource(self,
14401440 if p .use_source_vec_as_attention_value :
14411441 source_vecs = py_utils .HasShape (source_vecs ,
14421442 py_utils .GetShape (source_contexts ))
1443- time_steps , batch_size = py_utils .GetShape (source_padding , 2 )
1443+ [ time_steps ] = py_utils .GetShape (source_vecs , 1 )
14441444 # source_projected shape [time * source_batch, hidden]
14451445 if p .enable_source_proj :
14461446 source_vecs , w_source_proj = self .ToAqtInputs (
@@ -1466,8 +1466,7 @@ def PackSource(self,
14661466 num_heads = p .num_attention_heads
14671467 # => [time, source_batch * num_heads, hidden / num_heads]
14681468 source_projected = tf .reshape (source_projected , [
1469- time_steps , batch_size * num_heads ,
1470- symbolic .ToStatic (p .hidden_dim // num_heads )
1469+ time_steps , - 1 , symbolic .ToStatic (p .hidden_dim // num_heads )
14711470 ])
14721471 source_projected = gshard_utils .MeshSplit (source_projected , p .device_mesh ,
14731472 p .activation_split_dims_mapping )
0 commit comments