172 changes: 172 additions & 0 deletions projects/test_bert_load_huggingface_weight/load_huggingface_weight.py
@@ -0,0 +1,172 @@
from collections import OrderedDict

import oneflow as flow
import torch


def convert_tensor(tensor):
"""Convert pytorch_tensor to oneflow_tensor

Args:
tensor (torch.tensor): Weight of models' parameters in pytorch_model.bin

Returns:
oneflow.tensor
"""
tensor = tensor.float()
return flow.Tensor(tensor.cpu().numpy())


def convert_state_dict(state, layers, hidden_size, num_heads, head_size):
"""Convert pytorch_tensor to oneflow_tensor and save as a state_dict

Args:
state (OrderedDict): State_dict of Pytorch model.
layers (int): BERT's number of hidden layers.
hidden_size (int): The hidden_size of BERT.
num_heads (int): The num_head of BERT.
head_size (int): The Head_size of BERT.

Returns:
OrderedDict: State_dict of OneFlow model.
"""
save = OrderedDict()
not_saved = []
Layers = layers
for name, tensor in state.items():
if "embeddings" in name:
if "word_embeddings" in name:
save["embeddings.vocab_embeddings.weight"] = convert_tensor(tensor)
elif "position_embeddings" in name:
save["embeddings.position_embeddings.weight"] = convert_tensor(tensor)
elif "token_type_embeddings" in name:
save["embeddings.tokentype_embeddings.weight"] = convert_tensor(tensor)
elif "LayerNorm.gamma" in name:
save["encoders.0.input_layernorm.weight"] = convert_tensor(tensor)
elif "LayerNorm.beta" in name:
save["encoders.0.input_layernorm.bias"] = convert_tensor(tensor)

elif "attention" in name:
if "self" in name:
index = name.split(".")[3]
if "encoders." + index + ".self_attention.query_key_value.weight" in save.keys():
continue
q_w = name.replace(name.split(".")[6], "query").replace(
name.split(".")[7], "weight"
)
k_w = name.replace(name.split(".")[6], "key").replace(name.split(".")[7], "weight")
v_w = name.replace(name.split(".")[6], "value").replace(
name.split(".")[7], "weight"
)
q_b = name.replace(name.split(".")[6], "query").replace(name.split(".")[7], "bias")
k_b = name.replace(name.split(".")[6], "key").replace(name.split(".")[7], "bias")
v_b = name.replace(name.split(".")[6], "value").replace(name.split(".")[7], "bias")

                qkv_w = torch.cat((state[q_w], state[k_w], state[v_w]), dim=0)  # [3 * hidden_size, hidden_size]

                # Rearrange the loaded query/key/value weights; for details see:
                # https://libai.readthedocs.io/en/latest/notes/How_to_implement_huggingface%27s_weights_in_LiBai.html
qkv_w = qkv_w.view([3, num_heads, head_size, hidden_size])
qkv_w = qkv_w.permute(1, 0, 2, 3).contiguous().view(3 * hidden_size, hidden_size)
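                # Resulting weight layout: [num_heads, 3, head_size, hidden_size] flattened back
                # to [3 * hidden_size, hidden_size], i.e. rows are grouped per head
                # (q, k, v of head 0, then head 1, ...) instead of all q, then all k, then all v.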

qkv_b = torch.cat((state[q_b], state[k_b], state[v_b]), dim=-1)

                # Rearrange the loaded query/key/value biases in the same way; for details see:
                # https://libai.readthedocs.io/en/latest/notes/How_to_implement_huggingface%27s_weights_in_LiBai.html
qkv_b = qkv_b.view(3, num_heads, head_size)
qkv_b = qkv_b.permute(1, 0, 2).contiguous().view(-1)
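                # Resulting bias layout: [num_heads, 3, head_size] flattened to [3 * hidden_size],
                # matching the per-head grouping of the fused weight above.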

target_w = "encoders." + index + ".self_attention.query_key_value.weight"
save[target_w] = convert_tensor(qkv_w)
target_b = "encoders." + index + ".self_attention.query_key_value.bias"
save[target_b] = convert_tensor(qkv_b)
elif "output" in name:
index = name.split(".")[3]
if "dense" in name:
if "weight" in name:
target = "encoders." + index + ".self_attention.dense.weight"
save[target] = convert_tensor(tensor)
elif "bias" in name:
target = "encoders." + index + ".self_attention.dense.bias"
save[target] = convert_tensor(tensor)
elif "LayerNorm" in name:
if "gamma" in name:
target = "encoders." + index + ".post_attention_layernorm.weight"
save[target] = convert_tensor(tensor)
elif "beta" in name:
target = "encoders." + index + ".post_attention_layernorm.bias"
save[target] = convert_tensor(tensor)

elif "intermediate" in name:
index = name.split(".")[3]
if "encoders." + index + ".mlp.dense_h_to_4h.weight" in save.keys():
continue
w = "bert.encoder.layer." + index + ".intermediate.dense.weight"
b = "bert.encoder.layer." + index + ".intermediate.dense.bias"
t_w = "encoders." + index + ".mlp.dense_h_to_4h.weight"
t_b = "encoders." + index + ".mlp.dense_h_to_4h.bias"
save[t_w] = convert_tensor(state[w])
save[t_b] = convert_tensor(state[b])

elif "output" in name:
index = name.split(".")[3]
if "dense.weight" in name:
target = "encoders." + index + ".mlp.dense_4h_to_h.weight"
save[target] = convert_tensor(tensor)
elif "dense.bias" in name:
target = "encoders." + index + ".mlp.dense_4h_to_h.bias"
save[target] = convert_tensor(tensor)
elif "LayerNorm.gamma" in name:
if index == str(Layers - 1):
target = "final_layernorm.weight"
save[target] = convert_tensor(tensor)
continue
target = "encoders." + str(int(index) + 1) + ".input_layernorm.weight"
save[target] = convert_tensor(tensor)
elif "LayerNorm.beta" in name:
if index == str(Layers - 1):
target = "final_layernorm.bias"
save[target] = convert_tensor(tensor)
continue
target = "encoders." + str(int(index) + 1) + ".input_layernorm.bias"
save[target] = convert_tensor(tensor)

elif "pooler" in name:
if "weight" in name:
save["pooler.dense.weight"] = convert_tensor(tensor)
elif "bias" in name:
save["pooler.dense.bias"] = convert_tensor(tensor)
else:
not_saved.append(name)
return save, not_saved


def load_tensor(tensor_lhs, tensor_rhs):
"""Load the tensor to BERT.

Args:
tensor_lhs (flow.tensor): The tensor in state_dict.
tensor_rhs (flow.tensor): The tensor in LiBai's BERT.
"""
tensor_rhs = flow.to_global(tensor_rhs, placement=tensor_lhs.placement, sbp=tensor_lhs.sbp)
tensor_lhs.copy_(tensor_rhs)


def load_huggingface_bert(model, path, hidden_size, num_heads, layers=12):
"""Load Huggingface's pretrained weights in LiBai

Args:
model: BRET in LiBai.
path (str): The path of pretrained_model file.
"""
head_size = hidden_size // num_heads
huggingface_state_dict = torch.load(path)
of_state_dict, _ = convert_state_dict(
huggingface_state_dict,
layers=layers,
hidden_size=hidden_size,
num_heads=num_heads,
head_size=head_size,
)
for key, value in of_state_dict.items():
load_tensor(model.state_dict()[key], value)
2 changes: 2 additions & 0 deletions projects/test_bert_load_huggingface_weight/readme.md
@@ -0,0 +1,2 @@
# How to use Huggingface's pretrained weights in LiBai
This example shows how to convert Huggingface's pretrained BERT weights for use in LiBai. For more details, please refer to [How to use Huggingface's pretrained weights in LiBai](https://libai.readthedocs.io/en/latest/notes/How_to_implement_huggingface%27s_weights_in_LiBai.html).
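
Below is a minimal usage sketch mirroring `test_output.py` in this project. It assumes `test.sh` has already downloaded `pytorch_model.bin` to `./bert-base-chinese/`, that you run it from this directory, and it reuses the `bert-base-chinese` config values from `test_output.py`:

```python
import libai
from libai.config import LazyCall
from libai.models import build_model
from load_huggingface_weight import load_huggingface_bert

# BERT config matching bert-base-chinese (taken from test_output.py).
cfg = dict(
    vocab_size=21128,
    hidden_size=768,
    hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    num_tokentypes=2,
    add_pooling_layer=True,
    initializer_range=0.02,
    layernorm_eps=1e-12,
    bias_gelu_fusion=False,
    bias_dropout_fusion=False,
    scale_mask_softmax_fusion=False,
    apply_query_key_layer_scaling=False,
    add_binary_head=True,
    amp_enabled=False,
    apply_residual_post_layernorm=True,
)

# Build LiBai's BertModel and copy the Huggingface weights into it.
bert = build_model(LazyCall(libai.models.BertModel)(cfg=cfg))
load_huggingface_bert(
    bert,
    "./bert-base-chinese/pytorch_model.bin",
    hidden_size=cfg["hidden_size"],
    num_heads=cfg["num_attention_heads"],
    layers=cfg["hidden_layers"],
)
```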
11 changes: 11 additions & 0 deletions projects/test_bert_load_huggingface_weight/test.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED_PATH="./bert-base-chinese"

if [ ! -d "$PRETRAINED_PATH" ]; then
wget https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt -P ./bert-base-chinese/
wget https://huggingface.co/bert-base-chinese/resolve/main/pytorch_model.bin -P ./bert-base-chinese/
wget https://huggingface.co/bert-base-chinese/resolve/main/config.json -P ./bert-base-chinese/
fi

python3 test_output.py
85 changes: 85 additions & 0 deletions projects/test_bert_load_huggingface_weight/test_output.py
@@ -0,0 +1,85 @@
import unittest

import numpy as np
import oneflow as flow
import torch
import transformers
from load_huggingface_weight import load_huggingface_bert

import libai
from libai.config import LazyCall
from libai.models import build_model
from libai.utils import distributed as dist


class Test_BertModel_Use_Huggingface_Weight(unittest.TestCase):
def __init__(self, methodName="runTest"):
super().__init__(methodName)
self.input_ids = [[101, 1962, 2110, 739, 999, 1, 2, 3, 4, 102]]
        self.mask = [[1] * len(self.input_ids[0])]  # attention mask covering every token
        # LiBai's config
self.cfg = dict(
vocab_size=21128,
hidden_size=768,
hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
num_tokentypes=2,
add_pooling_layer=True,
initializer_range=0.02,
layernorm_eps=1e-12,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
add_binary_head=True,
amp_enabled=False,
apply_residual_post_layernorm=True,
)
self.bert_libai = build_model(LazyCall(libai.models.BertModel)(cfg=self.cfg))
load_huggingface_bert(
self.bert_libai,
"./bert-base-chinese/pytorch_model.bin",
self.cfg["hidden_size"],
self.cfg["num_attention_heads"],
)
self.input_ids_of = flow.tensor(
self.input_ids,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),
)
self.mask_of = flow.tensor(
self.mask,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.broadcast]),
placement=flow.placement("cuda" if flow.cuda.is_available() else "cpu", [0]),
)

        # Huggingface's model and inputs
self.bert_huggingface = transformers.BertModel.from_pretrained("./bert-base-chinese")
self.input_ids_pt = torch.tensor(self.input_ids)
self.mask_pt = torch.tensor(self.mask)

def test_output(self):
        # LiBai's BERT
self.bert_libai.eval()
last_hidden_state_of = self.bert_libai(self.input_ids_of, self.mask_of)[0]

        # Huggingface's BERT
self.bert_huggingface.eval()
last_hidden_state_pt = self.bert_huggingface(
self.input_ids_pt, self.mask_pt
).last_hidden_state

res1 = last_hidden_state_of.detach().numpy().sum()
res2 = last_hidden_state_pt.detach().numpy().sum()

self.assertTrue(np.around(res1, 4) == np.around(res2, 4))


if __name__ == "__main__":
unittest.main()