fastchat/model/compression.py (12 changes: 10 additions & 2 deletions)
@@ -11,7 +11,11 @@
 from torch.nn import functional as F
 import torch.nn as nn
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
+from transformers import AutoConfig, \
+    AutoModelForCausalLM, \
+    AutoTokenizer, \
+    AutoModel, \
+    AutoModelForSeq2SeqLM


 @dataclasses.dataclass
@@ -123,7 +127,11 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
     # some models are loaded by AutoModel but not AutoModelForCausalLM,
     # such as chatglm, chatglm2
     try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        # google/flan-* models are based on an AutoModelForSeq2SeqLM.
+        if 'T5Config' in str(type(config)):
+            model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
+        else:
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
     except NameError:
         model = AutoModel.from_config(config, trust_remote_code=True)
     linear_weights = get_compressed_list(model)
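To see why the class-name check catches these models: flan-t5 and flan-ul2 checkpoints ship a T5Config, so the string test selects the Seq2Seq branch. A quick sketch (google/flan-t5-small is just an illustrative checkpoint):

from transformers import AutoConfig

# Only the config is fetched here; no model weights are downloaded.
config = AutoConfig.from_pretrained("google/flan-t5-small")
print(type(config).__name__)            # "T5Config"
print('T5Config' in str(type(config)))  # True, so the Seq2Seq branch is taken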
fastchat/model/model_adapter.py (8 changes: 8 additions & 0 deletions)
@@ -644,6 +644,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer


+class FlanAdapter(T5Adapter):
+    """The model adapter for flan-t5-*, flan-ul2"""
+
+    def match(self, model_path: str):
+        return "flan" in model_path.lower()
+
Review discussion:

merrymercy (Member): What is the instruction template for flan? Do we need to set it here?

wangzhen263 (Contributor, Author): Hi @merrymercy, do you mean prompt template?

merrymercy (Member): yes

merrymercy (Member): @wangzhen263 Any updates? IIRC, flan-t5 has a default prompt template.
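If a template is wired in later, the override would sit on the adapter itself. A hypothetical sketch only: the template name "flan" is assumed for illustration, is not registered by this PR, and the thread above leaves the choice open:

from fastchat.conversation import Conversation, get_conv_template

class FlanAdapter(T5Adapter):
    """The model adapter for flan-t5-*, flan-ul2"""

    def match(self, model_path: str):
        return "flan" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        # Hypothetical: "flan" is an assumed template name, not one this
        # PR registers; see the reviewer's question above.
        return get_conv_template("flan")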


 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for koala"""

@@ -1587,6 +1594,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(LongChatAdapter)
 register_model_adapter(CodeT5pAdapter)
 register_model_adapter(T5Adapter)
+register_model_adapter(FlanAdapter)
 register_model_adapter(KoalaAdapter)
 register_model_adapter(AlpacaAdapter)
 register_model_adapter(ChatGLMAdapter)
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai", "anthropic>=0.3", "ray"]
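Regarding the new protobuf dependency: T5-family tokenizers are backed by sentencepiece model protos, and loading them without protobuf installed can fail. A rough illustration (again assuming google/flan-t5-small as the checkpoint):

from transformers import AutoTokenizer

# Parsing the sentencepiece model proto for T5-family tokenizers can
# raise an ImportError asking for the protobuf library if it is missing.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", use_fast=False)
print(tokenizer("Translate English to German: Hello."))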