Draft

28 commits
91e7c8a
disable compute_stream when data parallel
L1aoXingyu May 25, 2022
fa82eac
change input and label placement at the beginning
L1aoXingyu May 25, 2022
64e3307
add env variable
L1aoXingyu May 25, 2022
82e39f7
global tensor to local in graph build
L1aoXingyu May 25, 2022
9ddc9a2
finish softmax fusion
L1aoXingyu May 25, 2022
7cff1a0
turn off evaluation
L1aoXingyu May 25, 2022
f6df295
* add fused_scale_mask_softmax_dropout in bert and t5
L1aoXingyu May 25, 2022
14c2231
using graph block set stage by placement
chengtbf May 26, 2022
718e9ba
fix multihead_fusion loss non-decreasing
L1aoXingyu May 27, 2022
3709b6d
Merge branch 'lxy_libai_bench' of https://github.com/Oneflow-Inc/liba…
L1aoXingyu May 27, 2022
f0d2199
Add all set stage for libai models: resmlp, swin-t, t5, vit
chengtbf May 27, 2022
4ea8f5f
Merge branch 'lxy_libai_bench' of https://github.com/Oneflow-Inc/liba…
chengtbf May 27, 2022
d53fde5
remove expand by broadcast softmax dropout
chengtbf May 27, 2022
c91e39b
pull master and fix conflict
chengtbf May 31, 2022
c2ddf10
Merge branch 'lxy_libai_bench' of github.com:Oneflow-Inc/libai into l…
chengtbf May 31, 2022
90706bd
fix sbp for 2-D in loss cls_head attention all gather and all2all
chengtbf Jun 2, 2022
258dba7
change pipeline_num_layers in dist (#296)
CPFLAME Jun 7, 2022
31e5c40
Merge branch 'main' of github.com:Oneflow-Inc/libai into lxy_libai_bench
chengtbf Jun 7, 2022
87e10cf
fuse optimizer and fp16 cast
chengtbf Jun 8, 2022
2ebab25
disable fuse optim cast for zzk correctness bug fix
chengtbf Jun 13, 2022
1c30c50
del causal mask in gpt init && use fused tri in attention
CPFLAME Jun 13, 2022
85c188e
Merge branch 'lxy_libai_bench' of github.com:Oneflow-Inc/libai into l…
CPFLAME Jun 13, 2022
8fafd11
init rdma after dataloader
chengtbf Jun 13, 2022
c851591
refine zero config in graph base
chengtbf Jun 13, 2022
7001a2c
refine rdma && add persistent_workers in dataloader
CPFLAME Jun 14, 2022
6230697
delete rdma in graph trainer
CPFLAME Jun 14, 2022
4eb6efd
Merge branch 'main' into lxy_libai_bench
chengtbf Jun 17, 2022
e0efe7d
fix import dist
chengtbf Jun 18, 2022
2 changes: 1 addition & 1 deletion configs/common/models/bert.py
@@ -17,7 +17,7 @@
layernorm_eps=1e-5,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
scale_mask_softmax_fusion=False,
scale_mask_softmax_fusion=True,
apply_query_key_layer_scaling=True,
apply_residual_post_layernorm=False,
add_binary_head=True,
2 changes: 1 addition & 1 deletion configs/common/models/gpt.py
@@ -17,7 +17,7 @@
use_scaled_init_for_output_weights=True,
bias_gelu_fusion=True,
bias_dropout_fusion=True,
scale_mask_softmax_fusion=False,
scale_mask_softmax_fusion=True,
apply_query_key_layer_scaling=True,
apply_residual_post_layernorm=False,
amp_enabled=False,
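
Both the BERT and GPT configs now default to the fused scale-mask-softmax kernel. Below is a minimal sketch (not part of the PR) of overriding these flags from a user config; it assumes LiBai's `LazyConfig` loader and that the model config module exposes its hyperparameters in a dict named `cfg`, as the common configs do. The output path is hypothetical.

```python
from libai.config import LazyConfig  # assumption: LiBai's lazy-config entry point

model = LazyConfig.load("configs/common/models/gpt.py")
model.cfg.scale_mask_softmax_fusion = True   # the new default in this PR
model.cfg.bias_dropout_fusion = True         # unchanged, shown for context
LazyConfig.save(model, "gpt_fused.yaml")     # hypothetical output path
```
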
2 changes: 1 addition & 1 deletion configs/common/train.py
@@ -80,7 +80,7 @@
# You can set the maximum evaluation iterations to run for validation/test.
# You can also set a customized evaluator for use.
evaluation=dict(
enabled=True,
enabled=False,
# evaluator for calculating top-k acc
evaluator=LazyCall(ClsEvaluator)(topk=(1, 5)),
eval_period=5000,
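
Evaluation is switched off by default, which keeps benchmark runs to pure training. A short sketch of turning it back on from a user script, under the same `LazyConfig` assumption as above:

```python
from libai.config import LazyConfig

cfg = LazyConfig.load("configs/common/train.py")
cfg.train.evaluation.enabled = True      # re-enable periodic evaluation
cfg.train.evaluation.eval_period = 5000  # keep the period shown in this file
```
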
9 changes: 8 additions & 1 deletion libai/data/build.py
@@ -211,6 +211,7 @@ def build_nlp_train_loader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
@@ -259,7 +260,11 @@ def build_nlp_test_loader(
sampler = instantiate(sampler)

test_loader = DataLoader(
dataset, batch_sampler=sampler, num_workers=num_workers, collate_fn=collate_fn
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=collate_fn,
)
return test_loader

@@ -330,6 +335,7 @@ def build_image_train_loader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
@@ -383,6 +389,7 @@ def build_image_test_loader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
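
All four loaders now pass `persistent_workers`, so worker processes survive across epochs instead of being re-forked each time; the flag is only set when there are workers at all, since it is invalid with `num_workers=0`. A standalone sketch of the same pattern, assuming OneFlow's `DataLoader` follows the familiar PyTorch semantics; the dataset and shapes are made up.

```python
import oneflow as flow
from oneflow.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(flow.randn(256, 8), flow.randn(256, 1))
num_workers = 4

loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=num_workers,
    # Keep workers alive between epochs; only legal when workers exist.
    persistent_workers=num_workers > 0,
)

for epoch in range(2):
    for features, labels in loader:  # the second epoch reuses the same worker processes
        pass
```
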
10 changes: 5 additions & 5 deletions libai/data/structures.py
@@ -61,12 +61,12 @@ def to_global(self, sbp=None, placement=None):
# We do that to make sure that all the tensors used by the model are all generated
# by the fist device group, in case that each device group containg
# some random augmentations to the tensors without setting the same global seed.
main_placement = dist.get_layer_placement(0)
main_placement = dist.get_layer_placement(0, device_type="cpu") # put it on cpu first
self.tensor = self.tensor.to_global(sbp=self.sbp, placement=main_placement)
if self.placement_idx != 0:
self.tensor = self.tensor.to_global(
placement=dist.get_layer_placement(self.placement_idx)
)
# if self.placement_idx != 0:
# self.tensor = self.tensor.to_global(
# placement=dist.get_layer_placement(self.placement_idx)
# )

@staticmethod
def stack(distTensor_lists: List["DistTensorData"]) -> "DistTensorData":
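
`DistTensorData.to_global` now makes the tensor global on a CPU placement of the first device group, and the follow-up move to the tensor's pipeline-stage placement is commented out (the graph build takes care of it, per the "global tensor to local in graph build" commit). A small sketch of the CPU-first pattern with made-up shapes and a single-rank placement:

```python
import oneflow as flow

token_ids = flow.randint(0, 30000, (4, 512))        # stand-in for a dataloader batch

cpu_placement = flow.placement("cpu", ranks=[0])    # make it global on CPU first
token_ids = token_ids.to_global(placement=cpu_placement, sbp=flow.sbp.split(0))

cuda_placement = flow.placement("cuda", ranks=[0])  # later move, here done by hand
token_ids = token_ids.to_global(placement=cuda_placement)
```
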
6 changes: 4 additions & 2 deletions libai/engine/default.py
@@ -294,6 +294,8 @@ def __init__(self, cfg):

self.test_loader.extend(self.build_test_loader(cfg, self.tokenizer))

flow.env.init_rdma()

# Automatically scale the hyperparams
self.auto_scale_hyperparams(cfg, self.train_loader)

@@ -399,7 +401,7 @@ def build_hooks(self):
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(), # for beauty lr scheduler printer in `nn.Graph` mode
hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.train.checkpointer.period),
# hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.train.checkpointer.period),
]

if self.cfg.train.evaluation.enabled:
@@ -452,7 +454,7 @@ def build_writers(self):
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.global_batch_size, self.max_iter),
JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
# JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
TensorboardXWriter(self.cfg.train.output_dir),
]

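
`flow.env.init_rdma()` is now called only after both loaders are built, and the periodic checkpointer and JSON writer are commented out for the benchmark runs. The likely reason for the ordering is that dataloader workers are forked child processes, and forking after the RDMA context exists is unsafe; that reading is my assumption, but it matches the "init rdma after dataloader" commit. A sketch, with the multi-node guard also being an assumption:

```python
import oneflow as flow
from oneflow.utils.data import DataLoader, TensorDataset

# 1. Build the loaders first; with num_workers > 0 they fork worker processes now.
dataset = TensorDataset(flow.randn(64, 8))
train_loader = DataLoader(dataset, batch_size=8, num_workers=2, persistent_workers=True)

# 2. Only then bring up RDMA, so nothing is forked after the RDMA context exists.
if flow.env.get_node_size() > 1:
    flow.env.init_rdma()
```
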
20 changes: 10 additions & 10 deletions libai/engine/trainer.py
@@ -18,7 +18,6 @@
import weakref
from typing import Callable, List, Mapping

import numpy as np
import oneflow as flow

from libai.utils import distributed as dist
@@ -191,10 +190,11 @@ def write_metrics(
"""
# Only get metric value on rank0
# Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
metrics_dict = {
k: dist.tton(v, local_only=False, ranks=[0] if v.placement.ranks.ndim == 1 else [[0]])
for k, v in loss_dict.items()
}
# metrics_dict = {
# k: dist.tton(v, local_only=False, ranks=[0] if v.placement.ranks.ndim == 1 else [[0]])
# for k, v in loss_dict.items()
# }
metrics_dict = {k: dist.ttol(v, pure_local=True) for k, v in loss_dict.items()}
metrics_dict["data_time"] = data_time

# TODO: Gather metrics among all workers for logging
@@ -216,11 +216,11 @@
# }
metrics_dict = all_metrics_dict
total_losses_reduced = sum(metrics_dict.values())
if not np.isfinite(total_losses_reduced):
raise FloatingPointError(
f"Loss became infinite or NaN at iteration={storage.iter}!\n"
f"loss_dict = {metrics_dict}"
)
# if not np.isfinite(total_losses_reduced):
# raise FloatingPointError(
# f"Loss became infinite or NaN at iteration={storage.iter}!\n"
# f"loss_dict = {metrics_dict}"
# )

storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
if len(metrics_dict) > 1:
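
`write_metrics` now reads losses back with `dist.ttol(..., pure_local=True)` instead of `dist.tton`, and the per-iteration NaN check (the only `numpy` user) is disabled. Judging by the helper names, `ttol` keeps each rank's local tensor while `tton` synced the value to rank 0 as a numpy scalar, so the change avoids a device sync every iteration; that reading is an assumption. Both calls in the sketch below are taken from this diff:

```python
from libai.utils import distributed as dist

def local_metrics(loss_dict):
    # New path: keep each rank's local copy of the loss, no cross-rank gather.
    return {k: dist.ttol(v, pure_local=True) for k, v in loss_dict.items()}

def synced_metrics(loss_dict):
    # Old path: pull the value to rank 0 (2-D meshes need [[0]]) as a numpy scalar.
    return {
        k: dist.tton(v, local_only=False, ranks=[0] if v.placement.ranks.ndim == 1 else [[0]])
        for k, v in loss_dict.items()
    }
```
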
99 changes: 79 additions & 20 deletions libai/layers/attention.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import math
import os
from typing import Tuple

import oneflow as flow
@@ -22,6 +24,11 @@
from .linear import Linear


class AttnMaskType(enum.Enum):
padding = 1
causal = 2


class MultiheadAttention(nn.Module):
"""Multi-head attention layer, support self attention and cross attention.

@@ -59,10 +66,12 @@ def __init__(
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
attn_mask_type=AttnMaskType.padding,
*,
layer_idx=0
):
super().__init__()
self.multihead_attn_fusion = os.getenv("MULTIHEAD_ATTN_FUSION") is not None
self.hidden_size = hidden_size
if output_layer_init_method is None:
output_layer_init_method = init_method
@@ -73,7 +82,9 @@

self.num_heads = num_attention_heads
self.head_size = hidden_size // num_attention_heads
self.attn_mask_type = attn_mask_type

self.attention_dropout_prob = attention_dropout_prob
self.dropout = nn.Dropout(p=attention_dropout_prob)
self.norm_factor = 1.0 / math.sqrt(float(self.head_size))
self.coeff = None
@@ -123,6 +134,24 @@ def __init__(
layer_idx=layer_idx,
)

def fused_multihead_attn(self, h, attention_mask):
qmk, v = flow._C.fused_self_attention(
h, head_size=self.head_size, alpha=(1.0 / self.norm_factor)
)
if self.scale_mask_softmax_fusion:
attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
qmk, p=self.attention_dropout_prob, diagonal=0, tril_scale_value=self.coeff
)[0]
else:
if self.coeff is not None:
qmk *= self.coeff
attention_scores = flow.mul(qmk, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
return flow._C.matmul(attention_weights, v)

def forward(
self,
hidden_states: flow.Tensor,
@@ -184,11 +213,17 @@ def forward(
# hidden_states is the last-added state,
# the full key and value could be obtained by concatenating with past_key_value.
query_key_value = self.query_key_value(hidden_states)
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.permute(
0, 2, 1, 3
) # [bsz, num_heads, src_len, 3 * head_size]
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
if self.multihead_attn_fusion:
attention_scores, value = flow._C.fused_self_attention(
query_key_value, head_size=self.head_size, alpha=self.norm_factor
)
else:
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.permute(
0, 2, 1, 3
) # [bsz, num_heads, src_len, 3 * head_size]
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)

if past_key_value is not None:
past_key, past_value = past_key_value
key = flow.cat((past_key.type_as(key), key), dim=2)
@@ -198,41 +233,65 @@
if use_cache:
past_key_value = (key, value)

# [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)
if not self.multihead_attn_fusion:
# [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor)

# [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
if attention_mask is not None:
if self.scale_mask_softmax_fusion:
attention_weights = flow._C.fused_scale_mask_softmax(
attention_scores, attention_mask, fill_value=-10000.0
)
if self.scale_mask_softmax_fusion:
if self.attn_mask_type == AttnMaskType.padding:
# attention_mask = attention_mask.repeat(1, attention_scores.shape[1], 1, 1)
attention_weights = flow._C.fused_scale_mask_softmax_dropout(
attention_scores,
attention_mask,
fill_value=-10000.0,
scale=self.coeff,
p=self.attention_dropout_prob,
)[0]
else:
if self.coeff is not None:
attention_scores *= self.coeff
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
# TODO(l1aoxingyu): graph will occur `where_scalar` errors when using `masked_fill`
# TODO(xingyu.liao): graph will occur `where_scalar` errors
# when using `masked_fill`
# attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
else:
attention_weights = flow.softmax(attention_scores, dim=-1)

# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
if self.scale_mask_softmax_fusion:
if self.attn_mask_type == AttnMaskType.causal:
attention_weights = flow._C.fused_scale_tril_softmax_mask_scale(
attention_scores,
p=self.attention_dropout_prob,
diagonal=0,
tril_scale_value=self.coeff,
)[0]
else:
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)

# Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
context = flow.matmul(attention_weights, value)
# Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
context = context.transpose(1, 2)

if self.multihead_attn_fusion:
context = flow._C.transpose(context, perm=(2, 0, 1, 3))
else:
# Change shape: [bsz, num_heads, tgt_len, head_size] ->
# [bsz, tgt_len, num_heads, head_size]
# context = context.transpose(1, 2)
context = flow._C.transpose(context, perm=(0, 2, 1, 3))

# Concat multi-head results from
# [bsz, tgt_len, num_heads, head_size] -> [bsz, tgt_len, num_heads * head_size]
# SBP sign: [S(0), S(2)]
context = context.view(bsz, tgt_len, self.hidden_size)
# context = context.view(bsz, tgt_len, self.hidden_size)

# [S(0), S(2)] x [B, S(0)] = [S(0), P] -> [S(0), B]
output = self.dense(context)
output = self.dense(context.flatten(2))

if self.bias_dropout_fusion:
output, bias = output
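
The attention layer gains two fused paths: the `MULTIHEAD_ATTN_FUSION` environment variable switches QKV slicing to `flow._C.fused_self_attention`, and `scale_mask_softmax_fusion` now dispatches on `AttnMaskType` between `fused_scale_mask_softmax_dropout` (padding masks) and `fused_scale_tril_softmax_mask_scale` (causal masks). Below is a rough, hedged comparison of the causal fused kernel against an unfused reference built only from ops that appear in this diff; the shapes, the `-10000.0` fill in the reference, and `tril_scale_value=1.0` (standing in for `self.coeff`, which may be `None`) are assumptions of the sketch, not a claim of bit-exact equivalence.

```python
import oneflow as flow

bsz, heads, seq = 2, 4, 16
scores = flow.randn(bsz, heads, seq, seq, device="cuda")

# Unfused causal reference: lower-triangular mask, then softmax (dropout omitted, p=0).
causal = flow.tril(flow.ones(seq, seq, device="cuda"))
ref = flow.softmax(scores * causal - 10000.0 * (1 - causal), dim=-1)

# Fused causal path used when attn_mask_type == AttnMaskType.causal;
# index [0] is the output, the second return value is the dropout mask.
fused = flow._C.fused_scale_tril_softmax_mask_scale(
    scores, p=0.0, diagonal=0, tril_scale_value=1.0
)[0]
```
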
2 changes: 2 additions & 0 deletions libai/layers/cross_entropy.py
@@ -36,6 +36,8 @@ def forward(self, logits: flow.Tensor, target: flow.Tensor):
assert target.ndim == 2
assert logits.shape[0:2] == target.shape

target = target.to_global(placement=logits.placement)

# Change -1 in target to 0 because sparse_softmax_cross_entropy don't accept -1
target = target * (target >= 0)

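
The loss now explicitly moves the labels onto the logits' placement before the cross entropy runs; with `DistTensorData` making inputs global on a CPU placement first (see `libai/data/structures.py` above), logits and targets can otherwise sit on different placements. A minimal single-rank sketch of the added line:

```python
import oneflow as flow

cuda0 = flow.placement("cuda", ranks=[0])
cpu0 = flow.placement("cpu", ranks=[0])

logits = flow.randn(2, 8, 1000).to_global(placement=cuda0, sbp=flow.sbp.broadcast)
target = flow.randint(0, 1000, (2, 8)).to_global(placement=cpu0, sbp=flow.sbp.broadcast)

# The line this PR adds: align the labels with the logits' placement first.
target = target.to_global(placement=logits.placement)
assert target.placement == logits.placement
```
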
3 changes: 2 additions & 1 deletion libai/layers/linear.py
@@ -148,7 +148,8 @@ def forward(self, x):
# x.grad sbp must be x.sbp, otherwise backward pass cannot be performed correctly.
x = x.to_global(grad_sbp=x.sbp)
# Change x.sbp to [S(0), S(0)] if weight is [B, B]
x = x.to_global(sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(0)]))
# NOTE(chengcheng): when input x is [S(0), B], there is no need to change sbp for x.
# x = x.to_global(sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(0)]))
x = flow.matmul(x, self.weight, transpose_b=True)
else:
# Not supported weight_sbp, deduce sbp and communicate with nccl automatically.
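
The data-parallel branch of `Linear` no longer forces `x` to `[S(0), S(0)]` before the matmul: when the activation is already split over the batch and the weight is broadcast, the product keeps the data-parallel SBP on its own. A sketch of that situation, reduced from the PR comment's 2-D `[S(0), B]` case to one mesh dimension for brevity; it assumes a two-process launch.

```python
# Launch with: python3 -m oneflow.distributed.launch --nproc_per_node 2 this_file.py
import oneflow as flow

placement = flow.placement("cuda", ranks=[0, 1])

x = flow.randn(8, 16).to_global(placement=placement, sbp=flow.sbp.split(0))    # S(0)
w = flow.randn(32, 16).to_global(placement=placement, sbp=flow.sbp.broadcast)  # B

y = flow.matmul(x, w, transpose_b=True)  # output stays S(0); no sbp conversion needed
```
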
5 changes: 4 additions & 1 deletion libai/layers/transformer_layer.py
@@ -17,7 +17,7 @@

from libai.utils import distributed as dist

from .attention import MultiheadAttention
from .attention import AttnMaskType, MultiheadAttention
from .droppath import DropPath
from .layer_norm import LayerNorm
from .mlp import MLP
@@ -72,6 +72,7 @@ def __init__(
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
apply_residual_post_layernorm=False,
attn_mask_type=AttnMaskType.padding,
*,
layer_idx=0
):
@@ -82,6 +83,7 @@
self.attention_dropout_prob = attention_dropout_prob
self.output_dropout_prob = output_dropout_prob
self.layernorm_epsilon = layernorm_epsilon
self.attn_mask_type = attn_mask_type

self.layer_idx = layer_idx
self.is_decoder = is_decoder
@@ -241,5 +243,6 @@ def build_attention(self, is_cross_attention=False):
bias_dropout_fusion=self.bias_dropout_fusion,
scale_mask_softmax_fusion=self.scale_mask_softmax_fusion,
apply_query_key_layer_scaling=self.apply_query_key_layer_scaling,
attn_mask_type=self.attn_mask_type,
layer_idx=self.layer_idx,
)
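
`TransformerLayer` now threads `attn_mask_type` through to `MultiheadAttention`, so decoder-style (GPT) stacks can request the causal fused kernel while encoder-style stacks keep the padding default. A constructor sketch; the argument names come from this diff, but these particular sizes and the exported import paths are assumptions of the example.

```python
from libai.layers import TransformerLayer
from libai.layers.attention import AttnMaskType

decoder_block = TransformerLayer(
    hidden_size=768,
    ffn_hidden_size=3072,
    num_attention_heads=12,
    is_decoder=True,
    scale_mask_softmax_fusion=True,
    attn_mask_type=AttnMaskType.causal,   # picks the fused tril-softmax path
)
```
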