6 changes: 5 additions & 1 deletion configs/common/train.py
@@ -43,7 +43,11 @@

# Enable automatic mixed precision for training which does not
# change model's inference behavior.
amp=dict(enabled=False),
# Set `type` to `fp16` or `bf16` according to your needs.
amp=dict(
enabled=False,
type="fp16"
),

# Enable activation checkpointing to allow for training
# with larger models, sequences, and batch sizes.
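For reference, a project config layered on top of this default can switch the new fields at load time; a minimal sketch, mirroring the pattern the RWKV_v4 configs later in this diff use:

```python
# Sketch: overriding the new AMP options from a project-level config.
from libai.config import get_config

train = get_config("common/train.py").train
train.amp.enabled = True
train.amp.type = "bf16"  # "fp16" keeps the GradScaler path; "bf16" trains without loss scaling
```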
2 changes: 2 additions & 0 deletions libai/engine/default.py
@@ -541,9 +541,11 @@ def build_model(cls, cfg):
# insert for amp training if provided.
if try_get_key(cfg.model, "cfg.amp_enabled") is not None:
cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
cfg.model.cfg.amp_type = cfg.train.amp.type
# In case some model define without cfg keyword.
elif try_get_key(cfg.model, "amp_enabled") is not None:
cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
cfg.model.amp_type = cfg.train.amp.type
model = build_model(cfg.model)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
3 changes: 2 additions & 1 deletion libai/models/build.py
@@ -35,7 +35,8 @@ def build_graph(cfg, model, optimizer=None, lr_scheduler=None, is_train=False):
graph.model = model
graph.optimizer = optimizer
graph.lr_scheduler = lr_scheduler
graph.fp16 = try_get_key(cfg, "train.amp.enabled", default=False)
graph.amp = try_get_key(cfg, "train.amp.enabled", default=False)
graph.amp_type = try_get_key(cfg, "train.amp.type", default="fp16")
graph.activation_checkpoint = try_get_key(
cfg, "train.activation_checkpoint.enabled", default=False
)
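Because `train.amp.type` is read with a `"fp16"` default, configs written before this change keep their previous behavior. A minimal sketch of that lookup (the `try_get_key` import path is assumed):

```python
# Sketch of the backward-compatible lookup above; the import path of try_get_key
# is an assumption, its call pattern mirrors the diff.
from omegaconf import OmegaConf
from libai.config import try_get_key

cfg = OmegaConf.create({"train": {"amp": {"enabled": True}}})  # an older config with no `type` key
amp = try_get_key(cfg, "train.amp.enabled", default=False)     # -> True
amp_type = try_get_key(cfg, "train.amp.type", default="fp16")  # -> "fp16", so old configs keep fp16 behavior
```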
27 changes: 17 additions & 10 deletions libai/models/utils/graph_base.py
@@ -30,7 +30,8 @@ def __init__(
model: nn.Module,
optimizer: flow.optim.Optimizer = None,
lr_scheduler: flow.optim.lr_scheduler = None,
fp16=False,
amp=False,
amp_type="fp16",
activation_checkpoint=False,
grad_acc_steps=1,
zero_optim=False,
@@ -45,15 +46,21 @@

if is_train:
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
if fp16:
self.config.enable_amp(True)
grad_scaler = flow.amp.GradScaler(
init_scale=2 ** 30,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)
self.set_grad_scaler(grad_scaler)
if amp:
if amp_type == "fp16":
self.config.enable_amp(True)

grad_scaler = flow.amp.GradScaler(
init_scale=2 ** 12,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=1000,
)
self.set_grad_scaler(grad_scaler)
elif amp_type == "bf16":
self.config.enable_amp(True, dtype=flow.bfloat16)
else:
raise NotImplementedError(f"Unsupported amp type: {amp_type}")

if grad_acc_steps > 1:
self.config.set_gradient_accumulation_steps(grad_acc_steps)
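Stripped of LiBai's plumbing, the branching added above boils down to the following standalone sketch of an `nn.Graph`; the toy model, optimizer, and loss are illustrative placeholders, not part of this PR:

```python
# Toy sketch of the fp16/bf16 AMP branch added above (placeholder model/loss for illustration).
import oneflow as flow
from oneflow import nn


class ToyAMPGraph(nn.Graph):
    def __init__(self, model, optimizer, amp_type="fp16"):
        super().__init__()
        self.model = model
        self.add_optimizer(optimizer)
        if amp_type == "fp16":
            self.config.enable_amp(True)
            # fp16 has a narrow exponent range, so dynamic loss scaling is required.
            self.set_grad_scaler(
                flow.amp.GradScaler(
                    init_scale=2 ** 12,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=1000,
                )
            )
        elif amp_type == "bf16":
            # bf16 keeps fp32's exponent range, so no grad scaler is needed.
            self.config.enable_amp(True, dtype=flow.bfloat16)
        else:
            raise NotImplementedError(f"Unsupported amp type: {amp_type}")

    def build(self, x, y):
        loss = flow.nn.functional.mse_loss(self.model(x), y)
        loss.backward()
        return loss
```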
57 changes: 57 additions & 0 deletions projects/RWKV_v4/configs/config.py
@@ -0,0 +1,57 @@
from omegaconf import OmegaConf
import oneflow as flow

from libai.config import get_config
from libai.config import LazyCall
from libai.tokenizer import GPT2Tokenizer
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader

from projects.RWKV_v4.modeling.model import GPT, GPTConfig
from projects.RWKV_v4.utils.config_optimizer import get_RWKV_v4_config_optim
from projects.RWKV_v4.dataset import RWKVDataset

graph = get_config("common/models/graph.py").graph
train = get_config("common/train.py").train
optim = LazyCall(flow.optim.Adam)(
params=LazyCall(get_RWKV_v4_config_optim)(),
lr=8e-4,
)

model = LazyCall(GPT)(vocab_size=6064, ctx_len=1024, model_type="RWKV", n_layer=6, n_embd=512)

load_torch_checkpoint = OmegaConf.create()
load_torch_checkpoint.enable = True
load_torch_checkpoint.weight_style = "pytorch"
load_torch_checkpoint.path = "data_test/rwkv_test/for_load.pth"

datafile = "data_test/rwkv_test/enwik8"
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(RWKVDataset)(
data_dir=datafile,
ctx_len=1024,
epoch_length_fixed=9996,
),
],
num_workers=4,
)

train.train_micro_batch_size = 12
train.train_iter = 0
train.train_epoch = 1

train.amp.enabled = True
train.amp.type = "bf16"

train.rdma_enabled = False
train.activation_checkpoint.enabled = True

train.input_placement_device = "cpu"
train.dist.data_parallel_size = 1
train.dist.tensor_parallel_size = 1
train.dist.pipeline_parallel_size = 1
train.dist.pipeline_num_layers = model.n_layer

train.output_dir = "output/RWKV4/"
# train.load_weight = "output/RWKV_v4/model/model_final/"
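As a quick sanity check, the config above can be loaded and inspected without launching a run; a minimal sketch, assuming LiBai's lazy-config loader:

```python
# Sketch: loading the config file above and checking that the new AMP fields resolve.
# The LazyConfig.load entry point is assumed from LiBai's config system.
from libai.config import LazyConfig

cfg = LazyConfig.load("projects/RWKV_v4/configs/config.py")
print(cfg.train.amp.enabled, cfg.train.amp.type)  # expected: True bf16
print(cfg.train.dist.pipeline_num_layers)         # expected: 6, tied to model.n_layer
```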
59 changes: 59 additions & 0 deletions projects/RWKV_v4/configs/config_test.py
@@ -0,0 +1,59 @@
from omegaconf import OmegaConf

from libai.config import get_config
from libai.config import LazyCall
from libai.tokenizer import GPT2Tokenizer
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader
import oneflow as flow

from projects.RWKV_v4.modeling.model import GPT, GPTConfig
from projects.RWKV_v4.dataset import RWKVDataset
from projects.RWKV_v4.utils.config_optimizer import get_RWKV_v4_config_optim


load_torch_checkpoint = OmegaConf.create()
load_torch_checkpoint.enable = True
load_torch_checkpoint.weight_style = "pytorch"
load_torch_checkpoint.path = "data_test/rwkv_test/for_load.pth"

graph = get_config("common/models/graph.py").graph

graph.enabled = True

optim = LazyCall(flow.optim.Adam)(
params=LazyCall(get_RWKV_v4_config_optim)(),
lr=8e-4,
)

model = LazyCall(GPT)(vocab_size=6064, ctx_len=1024, model_type="RWKV", n_layer=6, n_embd=512)

train = get_config("common/train.py").train
train.input_placement_device = "cpu"
train.dist.pipeline_num_layers = 6
train.train_micro_batch_size = 4
train.scheduler = LazyCall(flow.optim.lr_scheduler.StepLR)(step_size=1000, gamma=1.0)

train.amp.enabled = True
train.amp.type = "bf16"

datafile = "data_test/rwkv_test/enwik8"

dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(RWKVDataset)(
data_dir=datafile,
ctx_len=1024,
epoch_length_fixed=9996,
),
],
num_workers=1,
)

train.train_iter = 300
# train.train_epoch = 1

train.output_dir = "output/rwkv_output_loss_compare"
train.rdma_enabled = False
train.evaluation.enabled = False
train.activation_checkpoint.enabled = True
1 change: 1 addition & 0 deletions projects/RWKV_v4/dataset/__init__.py
@@ -0,0 +1 @@
from .dataset import RWKVDataset
132 changes: 132 additions & 0 deletions projects/RWKV_v4/dataset/dataset.py
@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import random

import numpy as np
import oneflow as flow
import oneflow.nn.functional as F
from oneflow.utils.data import Dataset

from libai.data.structures import DistTensorData, Instance


class RWKVDataset(Dataset):
def __init__(self, data_dir, ctx_len, epoch_length_fixed):
with open(data_dir, "r", encoding="utf-8") as f:
    data = f.read()
print("building token list...", end=" ")
unique = sorted(list(set(data)))

xx = 0
xxObj = {}
for u in unique:
xxObj[xx] = u
xx += 1
# NOTE: comment out write json file.
# with open("vocab.json", "w", encoding="utf-16") as vocab_file:
# vocab_file.write(json.dumps(xxObj, ensure_ascii=False))

data_size, vocab_size = len(data), len(unique)
print("data has %d tokens, %d unique." % (data_size, vocab_size))
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}
self.ctx_len = ctx_len
self.epoch_length_fixed = epoch_length_fixed
self.vocab_size = vocab_size
self.data = data

def __len__(self):
return self.epoch_length_fixed

def __getitem__(self, idx):
# cheat: pick a random spot in dataset
i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
chunk = self.data[i : i + self.ctx_len + 1]
dix = [self.stoi[s] for s in chunk]
x = flow.tensor(dix[:-1], dtype=flow.long)
y = flow.tensor(dix[1:], dtype=flow.long)
return Instance(idx=DistTensorData(x), targets=DistTensorData(y))


class TOKENIZER:
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
self.word_table = json.load(result_file)

self.vocab_size = len(self.word_table)

self.stoi = {v: int(k) for k, v in self.word_table.items()}
self.itos = {int(k): v for k, v in self.word_table.items()}

self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

def refine_context(self, context):

context = context.strip().split("\n")
for c in range(len(context)):
context[c] = context[c].strip().strip("\u3000").strip("\r")
context = list(filter(lambda c: c != "", context))
context = "\n" + ("\n".join(context)).strip()
if context == "":
context = "\n"

return context

def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')

lastChar = int(x[-1])

probs = F.softmax(flow.tensor(out), dim=-1)

if self.itos[lastChar] == "\n":
top_p = top_p_newline
else:
top_p = top_p_usual

sorted_probs, s_index = flow.sort(probs, descending=True)

# for j in range(30):
# pp = sorted_probs[j].item()
# if pp < 0.005:
# break
# ss = self.itos[int(s_index[j])].replace('\n','_')
# print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
# print('')

cumulative_probs = flow.cumsum(sorted_probs, dim=-1).numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

probs[probs < cutoff] = 0
# print("[" + str(round(cutoff,4)) + ' ' +
# str(round(to_float(sum(probs)),3)) + "]", end = "")

if temperature != 1.0:
probs = probs.pow(1.0 / temperature)

return flow.multinomial(probs, num_samples=1)[0]


def to_float(x):
return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
flow.manual_seed(seed)
flow.cuda.manual_seed_all(seed)
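A minimal usage sketch of `RWKVDataset` on its own, using the same enwik8 path as the configs above (the file is assumed to exist locally):

```python
# Sketch: constructing and sampling the char-level dataset defined above.
from projects.RWKV_v4.dataset import RWKVDataset

dataset = RWKVDataset(
    data_dir="data_test/rwkv_test/enwik8",  # plain-text file; a char-level vocab is built from it
    ctx_len=1024,
    epoch_length_fixed=9996,
)
print(len(dataset))    # 9996: the fixed epoch length, independent of file size
sample = dataset[0]    # the index is ignored; a random ctx_len + 1 window is tokenized
print(sample)          # Instance with DistTensorData-wrapped input ids and next-char targets
```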