@@ -1442,31 +1442,26 @@ def PackSource(self,
14421442 py_utils .GetShape (source_contexts ))
14431443 time_steps , batch_size = py_utils .GetShape (source_padding , 2 )
14441444 # source_projected shape [time * source_batch, hidden]
1445- with tf .name_scope ('init__0a' ):
1446- source_vec_depth = py_utils .GetShape (source_vecs )[2 ]
1447- with tf .name_scope ('init__0b' ):
1448- if p .enable_source_proj :
1449- source_vecs = tf .reshape (source_vecs , [- 1 , source_vec_depth ])
1450- source_vecs , w_source_proj = self .ToAqtInputs (
1451- 'source_proj' ,
1452- act = source_vecs ,
1453- weight = theta .source_proj ,
1454- w_feature_axis = - 1 )
1455- w_source_proj = self .QWeight (w_source_proj )
1456- source_projected = tf .matmul (source_vecs , w_source_proj )
1457- source_projected = self .QAct ('source_proj_matmul' , source_projected )
1458- source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1459- if p .use_bias :
1460- source_projected = fns .qadd (
1461- source_projected ,
1462- self .QWeight (theta .source_proj_b ),
1463- qout_name = 'source_proj_add' )
1464- else :
1465- source_projected = tf .reshape (source_vecs , [- 1 , source_vec_depth ])
1466- if p .activation_split_dims_mapping :
1467- source_projected = gshard_utils .MeshSplit (
1468- source_projected , p .device_mesh ,
1469- p .activation_split_dims_mapping [1 :])
1445+ if p .enable_source_proj :
1446+ source_vecs , w_source_proj = self .ToAqtInputs (
1447+ 'source_proj' ,
1448+ act = source_vecs ,
1449+ weight = theta .source_proj ,
1450+ w_feature_axis = - 1 )
1451+ w_source_proj = self .QWeight (w_source_proj )
1452+ source_projected = tf .matmul (source_vecs , w_source_proj )
1453+ source_projected = self .QAct ('source_proj_matmul' , source_projected )
1454+ source_projected = self .FromAqtMatmul ('source_proj' , source_projected )
1455+ if p .use_bias :
1456+ source_projected = fns .qadd (
1457+ source_projected ,
1458+ self .QWeight (theta .source_proj_b ),
1459+ qout_name = 'source_proj_add' )
1460+ else :
1461+ source_projected = source_vecs
1462+ source_projected = gshard_utils .MeshSplit (
1463+ source_projected , p .device_mesh ,
1464+ p .activation_split_dims_mapping )
14701465 with tf .name_scope ('init__1' ):
14711466 num_heads = p .num_attention_heads
14721467 # => [time, source_batch * num_heads, hidden / num_heads]
@@ -1482,8 +1477,6 @@ def PackSource(self,
14821477 source_contexts_projected = source_projected
14831478 else :
14841479 if p .enable_ctx_pre_proj :
1485- source_contexts = tf .reshape (
1486- source_contexts , [- 1 , py_utils .GetShape (source_contexts )[2 ]])
14871480 source_contexts , w_ctx_proj = self .ToAqtInputs (
14881481 'ctx_proj' ,
14891482 act = source_contexts ,
@@ -1501,10 +1494,9 @@ def PackSource(self,
15011494 source_contexts_projected ,
15021495 self .QWeight (theta .ctx_proj_b ),
15031496 qout_name = 'ctx_pre_proj_add' )
1504- if p .activation_split_dims_mapping :
1505- source_contexts_projected = gshard_utils .MeshSplit (
1506- source_contexts_projected , p .device_mesh ,
1507- p .activation_split_dims_mapping [1 :])
1497+ source_contexts_projected = gshard_utils .MeshSplit (
1498+ source_contexts_projected , p .device_mesh ,
1499+ p .activation_split_dims_mapping )
15081500 else :
15091501 source_contexts_projected = source_contexts
15101502
@@ -3304,7 +3296,7 @@ def Params(cls):
33043296 p .Define ('source_dim' , 0 , 'Default source dimension.' )
33053297 p .Define (
33063298 'query_dim' , 0 , 'Number of query nodes. Child attention params '
3307- 'must have query_dim less or euqal than 0 or equal to this value.' )
3299+ 'must have query_dim less or equal than 0 or equal to this value.' )
33083300 p .Define (
33093301 'primary_source_key' , 'source_0' , 'Key for the primary source '
33103302 'whose attention probabilities will be used as an output.' )
0 commit comments