Merged
8 changes: 8 additions & 0 deletions fastchat/model/model_adapter.py
@@ -32,6 +32,7 @@
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
from fastchat.model.model_yuan2 import generate_stream_yuan2
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.model_xfastertransformer import generate_stream_xft
from fastchat.model.monkey_patch_non_inplace import (
@@ -388,6 +389,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
is_codet5p = "codet5p" in model_type
is_exllama = "exllama" in model_type
is_xft = "xft" in model_type
is_yuan = "yuan" in model_type

if is_chatglm:
return generate_stream_chatglm
@@ -399,6 +401,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
return generate_stream_exllama
elif is_xft:
return generate_stream_xft
elif is_yuan:
return generate_stream_yuan2

elif peft_share_base_weights and is_peft:
# Return a curried stream function that loads the right adapter
@@ -421,6 +425,8 @@ def generate_stream_peft(
is_codet5p = "codet5p" in base_model_type
is_exllama = "exllama" in base_model_type
is_xft = "xft" in base_model_type
is_yuan = "yuan" in base_model_type

generate_stream_function = generate_stream
if is_chatglm:
generate_stream_function = generate_stream_chatglm
@@ -432,6 +438,8 @@ def generate_stream_peft(
generate_stream_function = generate_stream_exllama
elif is_xft:
generate_stream_function = generate_stream_xft
elif is_yuan:
generate_stream_function = generate_stream_yuan2
for x in generate_stream_function(
model,
tokenizer,
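For orientation, a minimal caller-side sketch of how the new branch is reached. load_model and get_generate_stream_function are existing FastChat helpers; the checkpoint name and device settings below are illustrative assumptions, not part of this diff:

# Sketch only: assumes a Yuan2 checkpoint is available locally or via Hugging Face.
from fastchat.model.model_adapter import load_model, get_generate_stream_function
from fastchat.model.model_yuan2 import generate_stream_yuan2

model_path = "IEITYuan/Yuan2-2B-hf"  # illustrative checkpoint name
model, tokenizer = load_model(model_path, device="cuda", num_gpus=1)

# The model class name contains "yuan", so the dispatcher returns the new streamer.
stream_func = get_generate_stream_function(model, model_path)
assert stream_func is generate_stream_yuan2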
139 changes: 139 additions & 0 deletions fastchat/model/model_yuan2.py
@@ -0,0 +1,139 @@
import gc
from threading import Thread
from typing import Iterable

import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig

from fastchat.utils import is_partial_stop


@torch.inference_mode()
def generate_stream_yuan2(
model,
tokenizer,
params,
device,
context_len=2048,
stream_interval=2,
judge_sent_end=False,
):
prompt = params["prompt"]
len_prompt = len(prompt)
temperature = float(params.get("temperature", 1))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 0))
top_k = int(params.get("top_k", 1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 512))
stop_str = params.get("stop", "<eod>")
echo = bool(params.get("echo", True))
stop_token_ids = params.get("stop_token_ids", None) or []
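    # always include the id of Yuan2's <eod> end-of-generation marker as a stop token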
stop_token_ids.append(tokenizer("<eod>")["input_ids"][0])

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

max_src_len = context_len - max_new_tokens - 8

    # input_ids/attention_mask have shape (1, seq_len); truncate the sequence
    # dimension from the left so the prompt fits the context window
    input_ids = input_ids[:, -max_src_len:]
    attention_mask = attention_mask[:, -max_src_len:]
    input_echo_len = input_ids.shape[-1]

decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)

generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
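        # sampling is only enabled at relatively high temperatures; below 1.2 decoding is greedy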
do_sample=temperature >= 1.2,
temperature=temperature,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=10,
top_p=top_p,
top_k=top_k,
)

generation_kwargs = dict(
inputs=input_ids,
attention_mask=attention_mask,
streamer=streamer,
generation_config=generation_config,
)

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

if echo:
# echo: keep the prompt in the streamed output
output = prompt
else:
output = ""

for i, new_text in enumerate(streamer):
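        # accumulate streamer text; stop-string checks and yields happen only every stream_interval-th chunk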
output += new_text
if i % stream_interval == 0:
if echo:
rfind_start = len_prompt
else:
rfind_start = 0

partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
else:
partially_stopped = is_partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")

# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
output = output.strip()

# finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif partially_stopped:
finish_reason = None
else:
finish_reason = "stop"

yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}

# clean up accelerator memory
gc.collect()
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
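A hedged usage sketch of the new streamer, assuming a model and tokenizer loaded as in the earlier sketch; the prompt text and sampling parameters are illustrative only, not part of this diff:

# Sketch: consume the generator and keep the last chunk, which carries the
# full stripped output, the usage counters, and a finish_reason.
gen_params = {
    "prompt": "Introduce Beijing in one sentence.",
    "temperature": 1.0,
    "top_p": 0.9,
    "top_k": 5,
    "max_new_tokens": 128,
    "echo": False,
}

last_chunk = None
for last_chunk in generate_stream_yuan2(
    model, tokenizer, gen_params, device="cuda", context_len=2048
):
    pass

print(last_chunk["text"])
print(last_chunk["usage"], last_chunk["finish_reason"])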