Merged
Changes from 1 commit
codegeex support oneflow backend
BBuf committed Feb 13, 2023
commit 8869145a014b2f051ca80c63bfab0b88ed37d63e
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
__pycache__/
codegeex.egg-info/
168 changes: 91 additions & 77 deletions codegeex/oneflow/codegeex_model.py
@@ -328,100 +328,114 @@ def forward(
if get_key_value:
present = (key_layer, value_layer)

# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
origin_query_layer = query_layer
origin_key_layer = key_layer
origin_value_layer = value_layer

if hasattr(torch._C, 'fused_multi_head_attention_inference'):
if layer_past is not None:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
).transpose(0, 1)
else:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
).transpose(0, 1)
else:
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================

# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))

# [s, b, np, hn] -> [s, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)
# [s, b, np, hn] -> [s, b * np, hn]
query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1)
key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1)

# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

# change view to [b, np, s, s]
attention_scores = matmul_result.view(*output_size)
# change view to [b, np, s, s]
attention_scores = matmul_result.view(*output_size)

# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================

if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]

if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True

# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
attention_mask = ~attention_mask
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]

if context_length is not None:
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True

# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
attention_mask = ~attention_mask
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
else:
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
else:
attention_probs = self.softmax(attention_scores)

# =========================
# Context layer. [sq, b, hp]
# =========================
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)

# =========================
# Context layer. [sq, b, hp]
# =========================

# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]
# value_layer -> context layer.
# [sq, b, np, hn] --> [b, np, sq, hn]

# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))

# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [sq, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)

# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)

# =================
# Output. [sq, b, h]
# =================
# =================
# Output. [sq, b, h]
# =================

output = self.dense(context_layer)

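For context, the hunk above makes attention dispatch to OneFlow's fused multi-head attention kernel when it is available and fall back to the original matmul/softmax path otherwise. The sketch below is a simplified, hypothetical illustration of that pattern, not the exact CodeGeeX code: the tensor layout, the view calls, and the unmasked fallback are assumptions.

import math
import oneflow as torch  # the PR aliases oneflow as torch throughout


def attention(query, key, value, num_heads, causal=True):
    # Assumed layout: [seq, batch, hidden] half-precision CUDA tensors.
    seq, batch, hidden = query.shape
    if hasattr(torch._C, "fused_multi_head_attention_inference"):
        # Fused path: QK^T, scaling, (causal) masking, softmax and PV in one kernel.
        # The view(batch, seq, -1) calls mirror the diff; the kernel is assumed to
        # take [batch, seq, hidden] inputs, and the final transpose is assumed to
        # move the result back to [seq, batch, hidden].
        return torch._C.fused_multi_head_attention_inference(
            query.view(batch, seq, -1),
            key.view(batch, seq, -1),
            value.view(batch, seq, -1),
            num_heads,
            causal=causal,
        ).transpose(0, 1)
    # Unfused fallback, structurally like the else branch of the hunk
    # (attention-mask handling omitted for brevity).
    head_dim = hidden // num_heads
    q = query.view(seq, batch * num_heads, head_dim).transpose(0, 1)
    k = key.view(seq, batch * num_heads, head_dim).transpose(0, 1)
    v = value.view(seq, batch * num_heads, head_dim).transpose(0, 1)
    scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    probs = torch.nn.functional.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)  # [batch * heads, seq, head_dim]
    return out.transpose(0, 1).contiguous().view(seq, batch, hidden)

In the actual diff the fused path additionally distinguishes incremental decoding with layer_past (causal=False) from the first full-context pass (causal=True).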
204 changes: 204 additions & 0 deletions tests/test_inference_oneflow.py
@@ -0,0 +1,204 @@

import os
import copy
import time
import oneflow as torch
import random
import argparse
import numpy as np

from codegeex.oneflow.inference import get_token_stream
from codegeex.oneflow import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
Contributor Author
Optimization 4: matmul now supports fusion with bias_add.
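For illustration only, the flag set on this line roughly corresponds to the sketch below; the layer size, dtype, and device are arbitrary assumptions. With the flag enabled, the matmul and bias_add inside nn.Linear are expected to run as a single fused kernel.

import os

# Enable OneFlow's fused matmul + bias_add kernel for nn.Linear (the flag this test sets).
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"

import oneflow as flow

# Arbitrary example layer; a CUDA device and fp16 weights are assumed.
linear = flow.nn.Linear(5120, 5120).half().cuda()
x = flow.randn(1, 5120, dtype=flow.float16, device="cuda")
y = linear(x)  # matmul and bias_add are expected to dispatch to one fused kernel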


def model_provider(args):
"""Build the model."""

model = CodeGeeXModel(
args.hidden_size,
args.num_layers,
args.num_attention_heads,
args.padded_vocab_size,
args.max_position_embeddings
)

return model


def add_code_generation_args(parser):
group = parser.add_argument_group(title="code generation")
group.add_argument(
"--num-layers",
type=int,
default=39,
)
group.add_argument(
"--hidden-size",
type=int,
default=5120,
)
group.add_argument(
"--num-attention-heads",
type=int,
default=40,
)
group.add_argument(
"--padded-vocab-size",
type=int,
default=52224,
)
group.add_argument(
"--max-position-embeddings",
type=int,
default=2048,
)
group.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
group.add_argument(
"--greedy",
action="store_true",
default=False,
help="Use greedy sampling.",
)
group.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top p sampling.",
)
group.add_argument(
"--top-k",
type=int,
default=0,
help="Top k sampling.",
)
group.add_argument(
"--out-seq-length",
type=int,
default=2048,
help="Size of the output generated text.",
)
group.add_argument(
"--prompt-file",
type=str,
default="./test_prompt.txt",
)
group.add_argument(
"--tokenizer-path",
type=str,
default="./tokenizer",
)
group.add_argument(
"--load",
type=str,
)
group.add_argument(
"--state-dict-path",
type=str,
)
group.add_argument(
"--micro-batch-size",
type=int,
default=1,
)
group.add_argument(
"--quantize",
action="store_true",
)

return parser


def main():
parser = argparse.ArgumentParser()
parser = add_code_generation_args(parser)
args, _ = parser.parse_known_args()

print("Loading tokenizer ...")
tokenizer = CodeGeeXTokenizer(
tokenizer_path=args.tokenizer_path,
mode="codegeex-13b")

print("Loading state dict ...")
state_dict = torch.load(args.load, map_location="cpu")
state_dict = state_dict["module"]

print("Building CodeGeeX model ...")
model = model_provider(args)
model.load_state_dict(state_dict)
model.eval()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model.cuda()
torch.cuda.synchronize()
with open(args.prompt_file, "r") as f:
prompt = f.readlines()
prompt = "".join(prompt)

times = {}
out_seq_lengths = [args.out_seq_length]
micro_batch_size = args.micro_batch_size
seq_length = args.max_position_embeddings
for out_seq_length in out_seq_lengths:
print(f"Generating with out_seq_len {out_seq_length}...")

times[out_seq_length] = []
for prompt in [prompt]:
t0 = time.perf_counter()
tokens = tokenizer.encode_code(prompt)
print(tokens)
print("Current prompt:")
print(prompt)
n_token_prompt = len(tokens)
print("N_token_prompt:", n_token_prompt)
token_stream = get_token_stream(
model,
tokenizer,
seq_length,
out_seq_length,
[copy.deepcopy(tokens) for _ in range(micro_batch_size)],
micro_batch_size=micro_batch_size,
topk=args.top_k,
topp=args.top_p,
temperature=args.temperature,
greedy=args.greedy,
)
is_finished = [False for _ in range(micro_batch_size)]
for i, generated in enumerate(token_stream):
generated_tokens = generated[0]
for j in range(micro_batch_size):
if is_finished[j]:
continue
generated_token_numpy = generated_tokens[j].numpy()
if generated_token_numpy[-1] == tokenizer.eos_token_id or len(
generated_tokens[j]) >= out_seq_length:
is_finished[j] = True
generated_tokens_ = generated_token_numpy.tolist()
generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
generated_code = "".join(generated_code)
t1 = time.perf_counter()
print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
times[out_seq_length].append(t1 - t0)
print("================================= Generated code:")
print(generated_code)

if all(is_finished):
break

print(times)
for out_seq_length in times.keys():
print(out_seq_length, np.mean(times[out_seq_length]))

print("Generation finished.")


if __name__ == "__main__":
main()
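For reference, a hypothetical way to drive this test script from Python; the checkpoint and tokenizer paths are placeholders, and the flags mirror the argparse options defined above.

import subprocess

# Placeholder paths; substitute a real CodeGeeX-13B checkpoint and tokenizer directory.
subprocess.run(
    [
        "python", "tests/test_inference_oneflow.py",
        "--load", "/path/to/codegeex_13b.pt",
        "--tokenizer-path", "./tokenizer",
        "--prompt-file", "./test_prompt.txt",
        "--out-seq-length", "256",
        "--micro-batch-size", "1",
        "--temperature", "0.8",
        "--top-p", "0.95",
    ],
    check=True,
)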