Skip to content

Commit 64ef85f

Browse files
lingvo-bot authored and copybara-github committed
Avoid using inference-time arithmetic on inputs of reshapes, use -1 instead.
PiperOrigin-RevId: 587884292
1 parent e7db023 commit 64ef85f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

lingvo/core/attention.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1440,7 +1440,7 @@ def PackSource(self,
14401440
if p.use_source_vec_as_attention_value:
14411441
source_vecs = py_utils.HasShape(source_vecs,
14421442
py_utils.GetShape(source_contexts))
1443-
time_steps, batch_size = py_utils.GetShape(source_padding, 2)
1443+
[time_steps] = py_utils.GetShape(source_vecs, 1)
14441444
# source_projected shape [time * source_batch, hidden]
14451445
if p.enable_source_proj:
14461446
source_vecs, w_source_proj = self.ToAqtInputs(
@@ -1466,8 +1466,7 @@ def PackSource(self,
14661466
num_heads = p.num_attention_heads
14671467
# => [time, source_batch * num_heads, hidden / num_heads]
14681468
source_projected = tf.reshape(source_projected, [
1469-
time_steps, batch_size * num_heads,
1470-
symbolic.ToStatic(p.hidden_dim // num_heads)
1469+
time_steps, -1, symbolic.ToStatic(p.hidden_dim // num_heads)
14711470
])
14721471
source_projected = gshard_utils.MeshSplit(source_projected, p.device_mesh,
14731472
p.activation_split_dims_mapping)

0 commit comments

Comments
 (0)