diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 296b53c8f..0a660d9e9 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -23,7 +23,6 @@
     AutoTokenizer,
     LlamaTokenizer,
     LlamaForCausalLM,
-    T5Tokenizer,
 )
 
 from fastchat.constants import CPU_ISA
@@ -31,9 +30,7 @@
 from fastchat.modules.awq import AWQConfig, load_awq_quantized
 from fastchat.conversation import Conversation, get_conv_template
 from fastchat.model.compression import load_compress_model
-from fastchat.model.llama_condense_monkey_patch import (
-    replace_llama_with_condense,
-)
+from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
 from fastchat.model.model_chatglm import generate_stream_chatglm
 from fastchat.model.model_codet5p import generate_stream_codet5p
 from fastchat.model.model_falcon import generate_stream_falcon
@@ -616,11 +613,14 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("vicuna_v1.1")
 
 
-class CodeT5pAdapter(BaseModelAdapter):
-    """The model adapter for Salesforce/codet5p-6b"""
+class GoogleFlanAdapter(BaseModelAdapter):
+    """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""
 
     def match(self, model_path: str):
-        return "codet5p" in model_path.lower()
+        return any(
+            model_str in model_path.lower()
+            for model_str in ["flan-", "fastchat-t5", "codet5p"]
+        )
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         revision = from_pretrained_kwargs.get("revision", "main")
@@ -634,28 +634,6 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
-class T5Adapter(BaseModelAdapter):
-    """The model adapter for lmsys/fastchat-t5-3b-v1.0"""
-
-    def match(self, model_path: str):
-        return "t5" in model_path.lower()
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
-        )
-        return model, tokenizer
-
-
-class FlanAdapter(T5Adapter):
-    """The model adapter for flan-t5-*, flan-ul2"""
-
-    def match(self, model_path: str):
-        return "flan" in model_path.lower()
-
-
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for koala"""
 
@@ -1599,9 +1577,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(VicunaAdapter)
 register_model_adapter(AiroborosAdapter)
 register_model_adapter(LongChatAdapter)
-register_model_adapter(CodeT5pAdapter)
-register_model_adapter(T5Adapter)
-register_model_adapter(FlanAdapter)
+register_model_adapter(GoogleFlanAdapter)
 register_model_adapter(KoalaAdapter)
 register_model_adapter(AlpacaAdapter)
 register_model_adapter(ChatGLMAdapter)
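
Reviewer note: the old T5Adapter matched on the bare substring "t5", which also fires for any unrelated path containing "t5"; the consolidated GoogleFlanAdapter narrows matching to three explicit substrings. Below is a minimal, self-contained sketch of that matcher for a quick sanity check. matches_google_flan is a hypothetical stand-in for GoogleFlanAdapter.match, not code from this patch.

# Standalone sketch of the consolidated matcher (hypothetical helper, not in the patch).
def matches_google_flan(model_path: str) -> bool:
    # Same logic as GoogleFlanAdapter.match: any listed substring selects the adapter.
    return any(
        model_str in model_path.lower()
        for model_str in ["flan-", "fastchat-t5", "codet5p"]
    )

# All three model families previously handled by separate adapters still match:
assert matches_google_flan("google/flan-t5-xl")
assert matches_google_flan("google/flan-ul2")
assert matches_google_flan("lmsys/fastchat-t5-3b-v1.0")
assert matches_google_flan("Salesforce/codet5p-6b")
# Paths that merely contain "t5" no longer match (the old T5Adapter would have claimed them):
assert not matches_google_flan("google/mt5-large")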