126 changes: 126 additions & 0 deletions projects/test_bert_load_huggingface_weight/load_huggingface_weight.py
@@ -0,0 +1,126 @@
from collections import OrderedDict
import oneflow as flow
import torch


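# Convert a PyTorch tensor (possibly fp16 / on GPU) into a float32 OneFlow tensor.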
def convert_tensor(tensor):
tensor = tensor.float()
return flow.Tensor(tensor.cpu().numpy())

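# Map a HuggingFace BERT state dict onto LiBai BertModel parameter names:
# per-layer q/k/v projections are fused into a single query_key_value tensor,
# and keys with no LiBai counterpart are returned in the second value.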
def convert_state(state, layers, hidden_size, num_heads, head_size):
save = OrderedDict()
not_saved = []
Layers = layers
for name, tensor in state.items():
if 'embeddings' in name:
if 'word_embeddings' in name:
save['embeddings.vocab_embeddings.weight'] = convert_tensor(tensor)
elif 'position_embeddings' in name:
save['embeddings.position_embeddings.weight'] = convert_tensor(tensor)
elif 'token_type_embeddings' in name:
save['embeddings.tokentype_embeddings.weight'] = convert_tensor(tensor)
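            # The embedding-level LayerNorm maps to the first encoder's input_layernorm.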
elif 'LayerNorm.gamma' in name:
save['encoders.0.input_layernorm.weight'] = convert_tensor(tensor)
elif 'LayerNorm.beta' in name:
save['encoders.0.input_layernorm.bias'] = convert_tensor(tensor)

elif 'attention' in name:
if 'self' in name:
index = name.split('.')[3]
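                # Every q/k/v weight and bias key lands in this branch; fuse only once per layer.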
if 'encoders.'+ index +'.self_attention.query_key_value.weight' in save.keys():
continue
q_w = name.replace(name.split('.')[6], 'query').replace(name.split('.')[7], 'weight')
k_w = name.replace(name.split('.')[6], 'key').replace(name.split('.')[7], 'weight')
v_w = name.replace(name.split('.')[6], 'value').replace(name.split('.')[7], 'weight')
q_b = name.replace(name.split('.')[6], 'query').replace(name.split('.')[7], 'bias')
k_b = name.replace(name.split('.')[6], 'key').replace(name.split('.')[7], 'bias')
v_b = name.replace(name.split('.')[6], 'value').replace(name.split('.')[7], 'bias')

                qkv_w = torch.cat((state[q_w], state[k_w], state[v_w]), dim=0)  # [3 * hidden_size, hidden_size]
                # Reorder from (q, k, v)-major to head-major so each head's q, k, v rows
                # sit together, matching the fused query_key_value layout.
                qkv_w = qkv_w.view([3, num_heads, head_size, hidden_size])
                qkv_w = qkv_w.permute(1, 0, 2, 3).contiguous().view(3 * hidden_size, hidden_size)

                qkv_b = torch.cat((state[q_b], state[k_b], state[v_b]), dim=-1)
                # Apply the same per-head reordering to the fused bias.
                qkv_b = qkv_b.view(3, num_heads, head_size)
                qkv_b = qkv_b.permute(1, 0, 2).contiguous().view(-1)

target_w = 'encoders.'+ index + '.self_attention.query_key_value.weight'
save[target_w] = convert_tensor(qkv_w)
target_b = 'encoders.'+ index + '.self_attention.query_key_value.bias'
save[target_b] = convert_tensor(qkv_b)
elif 'output' in name:
index = name.split('.')[3]
if 'dense' in name:
if 'weight' in name:
target = 'encoders.'+ index +'.self_attention.dense.weight'
save[target] = convert_tensor(tensor)
elif 'bias' in name:
target = 'encoders.'+ index +'.self_attention.dense.bias'
save[target] = convert_tensor(tensor)
elif 'LayerNorm' in name:
if 'gamma' in name:
target = 'encoders.'+ index +'.post_attention_layernorm.weight'
save[target] = convert_tensor(tensor)
elif 'beta' in name:
target = 'encoders.'+ index +'.post_attention_layernorm.bias'
save[target] = convert_tensor(tensor)

elif 'intermediate' in name:
index = name.split('.')[3]
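            # Weight and bias are both written on the first matching key; skip the second pass.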
if 'encoders.'+ index +'.mlp.dense_h_to_4h.weight' in save.keys():
continue
w = 'bert.encoder.layer.'+ index + '.intermediate.dense.weight'
b = 'bert.encoder.layer.'+ index + '.intermediate.dense.bias'
t_w = 'encoders.'+ index +'.mlp.dense_h_to_4h.weight'
t_b = 'encoders.'+ index +'.mlp.dense_h_to_4h.bias'
save[t_w] = convert_tensor(state[w])
save[t_b] = convert_tensor(state[b])

elif 'output' in name:
index = name.split('.')[3]
if 'dense.weight' in name:
target = 'encoders.' + index + '.mlp.dense_4h_to_h.weight'
save[target] = convert_tensor(tensor)
elif 'dense.bias' in name:
target = 'encoders.'+ index +'.mlp.dense_4h_to_h.bias'
save[target] = convert_tensor(tensor)
elif 'LayerNorm.gamma' in name:
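                # HuggingFace applies LayerNorm after each layer's output; LiBai expects it as
                # the next layer's input_layernorm (or final_layernorm for the last layer).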
if index == str(Layers-1):
target = 'final_layernorm.weight'
save[target] = convert_tensor(tensor)
continue
target = 'encoders.' + str(int(index)+1) + '.input_layernorm.weight'
save[target] = convert_tensor(tensor)
elif 'LayerNorm.beta' in name:
if index == str(Layers-1):
target = 'final_layernorm.bias'
save[target] = convert_tensor(tensor)
continue
target = 'encoders.' + str(int(index)+1) + '.input_layernorm.bias'
save[target] = convert_tensor(tensor)

elif 'pooler' in name:
if 'weight' in name:
save['pooler.dense.weight'] = convert_tensor(tensor)
elif 'bias' in name:
save['pooler.dense.bias'] = convert_tensor(tensor)
else:
not_saved.append(name)
return save, not_saved


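# Copy a converted local tensor into the model parameter, matching its placement and SBP.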
def load_tensor(tensor_lhs, tensor_rhs):
tensor_rhs = flow.to_global(tensor_rhs, placement=tensor_lhs.placement, sbp=tensor_lhs.sbp)
tensor_lhs.copy_(tensor_rhs)


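# Load HuggingFace BERT weights from a pytorch_model.bin file into a LiBai BertModel in place.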
def load_huggingface_bert(model, path, hidden_size, num_heads, layers=12):
head_size = hidden_size // num_heads
huggingface_state_dict = torch.load(path)
    of_state_dict, _ = convert_state(huggingface_state_dict, layers=layers, hidden_size=hidden_size, num_heads=num_heads, head_size=head_size)
for key, value in of_state_dict.items():
load_tensor(model.state_dict()[key], value)
11 changes: 11 additions & 0 deletions projects/test_bert_load_huggingface_weight/test.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED_PATH="./bert-base-chinese"

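# Fetch the bert-base-chinese vocab, weights, and config from the HuggingFace Hub if missing.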
if [ ! -d "$PRETRAINED_PATH" ]; then
wget https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt -P ./bert-base-chinese/
wget https://huggingface.co/bert-base-chinese/resolve/main/pytorch_model.bin -P ./bert-base-chinese/
wget https://huggingface.co/bert-base-chinese/resolve/main/config.json -P ./bert-base-chinese/
fi

python3 test_output.py
56 changes: 56 additions & 0 deletions projects/test_bert_load_huggingface_weight/test_output.py
@@ -0,0 +1,56 @@
import oneflow as flow
import libai
from libai.models import build_model
from libai.config import LazyCall
from load_huggingface_weight import load_huggingface_bert
from libai.utils import distributed as dist
import transformers
import torch
import numpy as np


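# A hand-written bert-base-chinese token id sequence; the attention mask keeps every position.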
input_ids = [[101, 1962, 2110, 739, 999, 1, 2, 3, 102]]
mask = [[1] * len(input_ids[0])]

# libai's Bert
cfg = dict(
vocab_size=21128,
hidden_size=768,
hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
num_tokentypes=2,
add_pooling_layer=True,
initializer_range=0.02,
layernorm_eps=1e-12,
    bias_gelu_fusion=False,
    bias_dropout_fusion=False,
    scale_mask_softmax_fusion=False,
    apply_query_key_layer_scaling=False,
add_binary_head=True,
amp_enabled=False,
apply_residual_post_layernorm=True
)
bert_lib = build_model(LazyCall(libai.models.BertModel)(cfg=cfg))
load_huggingface_bert(bert_lib, './bert-base-chinese/pytorch_model.bin', cfg['hidden_size'], cfg['num_attention_heads'])
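# Wrap the inputs as global tensors on device 0: split(0) over the data-parallel dim,
# broadcast over the model-parallel dim, as LiBai's distributed API expects.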
placement = flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0])
sbp = dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast])
input_of = flow.tensor(input_ids, dtype=flow.long, sbp=sbp, placement=placement)
mask_of = flow.tensor(mask, dtype=flow.long, sbp=sbp, placement=placement)
bert_lib.eval()
last_hidden_state_of, pooler_output_of = bert_lib(input_of, mask_of)


# huggingface's Bert
bert_hug = transformers.BertModel.from_pretrained('./bert-base-chinese')
bert_hug.eval()
input_pt = torch.tensor(input_ids)
mask_pt = torch.tensor(mask)
last_hidden_state_pt = bert_hug(input_pt, mask_pt).last_hidden_state

res1 = last_hidden_state_of.detach().numpy()
res2 = last_hidden_state_pt.detach().numpy()

print(res1.sum())
print(res2.sum())
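
# Optional stricter check (a sketch; the tolerance is a judgment call, since fp32 kernels
# differ slightly across frameworks):
print("max abs diff:", np.abs(res1 - res2).max())
assert np.allclose(res1, res2, atol=1e-4)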