diff --git a/fastchat/model/compression.py b/fastchat/model/compression.py
index 4a1d2adb7..f85ae3da4 100644
--- a/fastchat/model/compression.py
+++ b/fastchat/model/compression.py
@@ -11,7 +11,11 @@ from torch.nn import functional as F
 import torch.nn as nn
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
+from transformers import AutoConfig, \
+    AutoModelForCausalLM, \
+    AutoTokenizer, \
+    AutoModel, \
+    AutoModelForSeq2SeqLM
 
 
 @dataclasses.dataclass
 class CompressionConfig:
@@ -123,7 +127,11 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
         # some models are loaded by AutoModel but not AutoModelForCausalLM,
         # such as chatglm, chatglm2
         try:
-            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+            # google/flan-* models are based on an AutoModelForSeq2SeqLM.
+            if 'T5Config' in str(type(config)):
+                model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
+            else:
+                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
         except NameError:
             model = AutoModel.from_config(config, trust_remote_code=True)
         linear_weights = get_compressed_list(model)
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 8c2fbde32..57df05dd0 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -644,6 +644,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
+class FlanAdapter(T5Adapter):
+    """The model adapter for flan-t5-*, flan-ul2"""
+
+    def match(self, model_path: str):
+        return "flan" in model_path.lower()
+
+
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for koala"""
 
@@ -1587,6 +1594,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(LongChatAdapter)
 register_model_adapter(CodeT5pAdapter)
 register_model_adapter(T5Adapter)
+register_model_adapter(FlanAdapter)
 register_model_adapter(KoalaAdapter)
 register_model_adapter(AlpacaAdapter)
 register_model_adapter(ChatGLMAdapter)
diff --git a/pyproject.toml b/pyproject.toml
index 6c1d12f5e..278b17b6c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai", "anthropic>=0.3", "ray"]
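
For context, the compression.py hunk works because T5-family checkpoints (flan-t5-*, flan-ul2) are encoder-decoder models and cannot be instantiated through AutoModelForCausalLM. Below is a minimal standalone sketch of the same config-based dispatch, not part of this change; the helper name and the example checkpoint ids are illustrative only.

# sketch_seq2seq_dispatch.py -- mirrors the branch added to load_compress_model
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM


def build_model_from_config(model_path: str):
    # Load only the config; the class name check mirrors the patch, where
    # T5Config covers google/flan-t5-* and google/flan-ul2.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if "T5Config" in str(type(config)):
        # Encoder-decoder checkpoints need the seq2seq auto class.
        return AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
    # Decoder-only checkpoints keep the original causal-LM path.
    return AutoModelForCausalLM.from_config(config, trust_remote_code=True)


# Example (illustrative model ids):
#   build_model_from_config("google/flan-t5-small")  # -> T5ForConditionalGeneration
#   build_model_from_config("lmsys/vicuna-7b-v1.5")  # -> LlamaForCausalLM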