use fuse multi head att
xiezipeng-ML authored and xiezipeng-ML committed Oct 28, 2022
commit a1fc307db5d920e138166716c2798af38f85acf5
projects/T5/configs/mt5_pretrain.py (2 additions, 1 deletion)
@@ -18,7 +18,7 @@
 train_data_path = "projects/T5/data/training_data/part_0"
 pretrained_model_path = None
 
-micro_batch_size = 64
+micro_batch_size = 16
 optim["lr"] = 1e-4
 
 # dataloader
@@ -54,6 +54,7 @@
 model.cfg.embedding_dropout_prob = 0.0
 model.cfg.layernorm_eps = 1e-6
 model.cfg.model_type = "mt5"
+model.cfg.scale_mask_softmax_fusion = True
 model.cfg.pretrained_model_path = pretrained_model_path
 
 train.update(
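For reference, the new setting is just a config knob: flipping it off in the same file falls back to the eager mask-and-softmax branch kept in attention.py below. A minimal sketch, assuming the project's usual LazyConfig layout; only the flag name and the `model.cfg` object come from the diff above.

# In projects/T5/configs/mt5_pretrain.py (or a local override): disable the fused kernel,
# e.g. when debugging numerical differences, and fall back to flow.mul + flow.softmax.
model.cfg.scale_mask_softmax_fusion = False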
projects/T5/configs/t5_model_config.py (1 addition)
@@ -14,6 +14,7 @@
     embedding_dropout_prob=0.1,
     initializer_range=0.02,
     layernorm_eps=1e-5,
+    scale_mask_softmax_fusion=True,
     amp_enabled=False,
     model_type="t5",
 )
projects/T5/models/attention.py (22 additions, 8 deletions)
@@ -54,6 +54,7 @@ def __init__(
         output_dropout_prob=0.0,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         has_relative_attention_bias=False,
@@ -65,6 +66,7 @@
         self.has_relative_attention_bias = has_relative_attention_bias
         self.is_decoder = is_decoder
         self.attention_dropout_prob = attention_dropout_prob
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
 
         if output_layer_init_method is None:
             output_layer_init_method = init_method
@@ -230,14 +232,26 @@ def forward(
 
         # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
         if attention_mask is not None:
-            attention_scores = flow.mul(attention_scores, attention_mask)
-            attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
-            # TODO(xingyu.liao): graph will occur `where_scalar` errors
-            # when using `masked_fill`
-            # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
-            attention_weights = flow.softmax(attention_scores, dim=-1)
-            # [bsz, num_heads, tgt_len, src_len]
-            attention_weights = self.dropout(attention_weights)
+            if self.scale_mask_softmax_fusion:
+                attention_mask = (
+                    attention_mask.expand_as(attention_scores) if use_cache else attention_mask
+                )
+                attention_weights = flow._C.fused_scale_mask_softmax_dropout(
+                    attention_scores,
+                    attention_mask,
+                    fill_value=-10000.0,
+                    scale=1,
+                    p=self.attention_dropout_prob,
+                )[0]
+            else:
+                attention_scores = flow.mul(attention_scores, attention_mask)
+                attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
+                # TODO(xingyu.liao): graph will occur `where_scalar` errors
+                # when using `masked_fill`
+                # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
+                attention_weights = flow.softmax(attention_scores, dim=-1)
+                # [bsz, num_heads, tgt_len, src_len]
+                attention_weights = self.dropout(attention_weights)
         else:
             attention_weights = flow.softmax(attention_scores, dim=-1)
             # [bsz, num_heads, tgt_len, src_len]
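For context, the fused call above collapses the mask, fill, softmax, and dropout of the else-branch into one kernel. Below is a minimal equivalence sketch, not part of the PR, assuming a CUDA build of oneflow, a 0/1 keep-mask as in the surrounding code, and dropout disabled (p=0.0) so both paths are deterministic; the mask dtype the kernel expects can vary across oneflow versions, so cast as needed.

import oneflow as flow

bsz, num_heads, seq_len = 2, 8, 16
scores = flow.randn(bsz, num_heads, seq_len, seq_len, device="cuda")
# Same convention as attention.py: 1 = keep the score, 0 = mask it out.
keep = flow.randint(0, 2, (bsz, num_heads, seq_len, seq_len), device="cuda").to(flow.bool)

# Fused path: scale, mask-fill, softmax (and dropout) in a single kernel; [0] is the output.
fused = flow._C.fused_scale_mask_softmax_dropout(
    scores, keep, fill_value=-10000.0, scale=1, p=0.0
)[0]

# Eager reference, mirroring the non-fused branch of MultiheadAttention.forward.
mask = keep.to(scores.dtype)
eager = flow.softmax(flow.mul(scores, mask) - 10000.0 * (1 - mask), dim=-1)

print(flow.allclose(fused, eager, atol=1e-4))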
projects/T5/models/t5_model.py (4 additions)
@@ -41,6 +41,7 @@ def __init__(
         hidden_dropout_prob,
         attention_probs_dropout_prob,
         relative_attention_num_buckets,
+        scale_mask_softmax_fusion=True,
         initializer_range=0.02,
         layernorm_eps=1e-12,
         amp_enabled=False,
@@ -73,6 +74,7 @@ def __init__(
                 layernorm_epsilon=layernorm_eps,
                 init_method=init_method,
                 output_layer_init_method=scaled_init_method,
+                scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                 layer_idx=i,
                 model_type=model_type,
                 has_relative_attention_bias=bool(i == 0),
@@ -105,6 +107,7 @@ def __init__(
                 layernorm_epsilon=layernorm_eps,
                 init_method=init_method,
                 output_layer_init_method=scaled_init_method,
+                scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                 layer_idx=i,
                 model_type=model_type,
                 has_relative_attention_bias=bool(i - hidden_layers == 0),
@@ -150,6 +153,7 @@ def from_config(cls, cfg):
             "layernorm_eps": cfg.layernorm_eps,
             "amp_enabled": cfg.amp_enabled,
             "model_type": cfg.model_type,
+            "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
         }
 
     def forward(
projects/T5/models/transformer_layer.py (6 additions)
@@ -58,6 +58,7 @@ def __init__(
         layernorm_epsilon=1e-5,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         model_type="t5",
@@ -73,6 +74,7 @@ def __init__(
         self.layernorm_epsilon = layernorm_epsilon
         self.layer_idx = layer_idx
         self.is_decoder = is_decoder
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
 
         self.init_method = init_method
         if output_layer_init_method is None:
@@ -89,6 +91,7 @@ def __init__(
             is_cross_attention=False,
             relative_attention_num_buckets=relative_attention_num_buckets,
             has_relative_attention_bias=has_relative_attention_bias,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             is_decoder=self.is_decoder,
         )
         self.post_attention_layernorm = LayerNorm(
@@ -99,6 +102,7 @@ def __init__(
             self.cross_attention = self.build_attention(
                 is_cross_attention=True,
                 relative_attention_num_buckets=relative_attention_num_buckets,
+                scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                 is_decoder=self.is_decoder,
             )
             self.post_cross_attention_layernorm = LayerNorm(
@@ -234,6 +238,7 @@ def build_attention(
         is_cross_attention=False,
         relative_attention_num_buckets=None,
         has_relative_attention_bias=False,
+        scale_mask_softmax_fusion=True,
         is_decoder=False,
     ):
         return MultiheadAttention(
@@ -246,6 +251,7 @@
             output_dropout_prob=self.output_dropout_prob,
             init_method=self.init_method,
             output_layer_init_method=self.output_layer_init_method,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             layer_idx=self.layer_idx,
             has_relative_attention_bias=has_relative_attention_bias,
             is_decoder=is_decoder,