
Commit b49d789

Added google/flan models and fixed AutoModelForSeq2SeqLM when loading T5 compression model (lm-sys#2402)
1 parent: a8088ba

3 files changed: +23 −3 lines

fastchat/model/compression.py

Lines changed: 14 additions & 2 deletions
@@ -11,7 +11,13 @@
 from torch.nn import functional as F
 import torch.nn as nn
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+)
 
 
 @dataclasses.dataclass
@@ -123,7 +129,13 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
     # some models are loaded by AutoModel but not AutoModelForCausalLM,
     # such as chatglm, chatglm2
     try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        # google/flan-* models are based on an AutoModelForSeq2SeqLM.
+        if "T5Config" in str(type(config)):
+            model = AutoModelForSeq2SeqLM.from_config(
+                config, trust_remote_code=True
+            )
+        else:
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
     except NameError:
         model = AutoModel.from_config(config, trust_remote_code=True)
     linear_weights = get_compressed_list(model)
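Why the string check: transformers config classes carry the model family in their type name, so testing str(type(config)) for "T5Config" detects T5-family (encoder-decoder) models without importing T5Config directly. A minimal sketch of the same dispatch, assuming a recent transformers release; the model name is illustrative:

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM

# A T5-family config prints as
# "<class 'transformers.models.t5.configuration_t5.T5Config'>",
# so the substring test catches flan-t5 and other T5 variants.
config = AutoConfig.from_pretrained("google/flan-t5-small")

if "T5Config" in str(type(config)):
    # Encoder-decoder checkpoints need the Seq2SeqLM head; from_config
    # builds the architecture without downloading any weights.
    model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
else:
    # Decoder-only checkpoints keep the CausalLM head, as before.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)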

fastchat/model/model_adapter.py

Lines changed: 8 additions & 0 deletions
@@ -649,6 +649,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
+class FlanAdapter(T5Adapter):
+    """The model adapter for flan-t5-*, flan-ul2"""
+
+    def match(self, model_path: str):
+        return "flan" in model_path.lower()
+
+
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for koala"""
 
@@ -1592,6 +1599,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(LongChatAdapter)
 register_model_adapter(CodeT5pAdapter)
 register_model_adapter(T5Adapter)
+register_model_adapter(FlanAdapter)
 register_model_adapter(KoalaAdapter)
 register_model_adapter(AlpacaAdapter)
 register_model_adapter(ChatGLMAdapter)
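For context on how the new adapter is picked: FastChat walks the adapter registry in registration order and uses the first adapter whose match() returns True, which is why the catch-all base adapter registers last. A minimal sketch of that dispatch under the first-match assumption; get_model_adapter here is a simplification and the model paths are illustrative:

model_adapters = []

def register_model_adapter(cls):
    model_adapters.append(cls())

class BaseModelAdapter:
    def match(self, model_path: str):
        return True  # catch-all fallback

class T5Adapter(BaseModelAdapter):
    def match(self, model_path: str):
        return "t5" in model_path.lower()

class FlanAdapter(T5Adapter):
    """The model adapter for flan-t5-*, flan-ul2"""

    def match(self, model_path: str):
        return "flan" in model_path.lower()

register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(BaseModelAdapter)  # must stay last

def get_model_adapter(model_path: str):
    # First registered adapter that claims the path wins.
    for adapter in model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No adapter for {model_path}")

print(type(get_model_adapter("google/flan-ul2")).__name__)    # FlanAdapter
print(type(get_model_adapter("google/flan-t5-xl")).__name__)  # T5Adapter

Note that flan-t5 paths still resolve to T5Adapter, which registers first and matches the "t5" substring; that is harmless because FlanAdapter subclasses T5Adapter and inherits the same loading behavior, while the new match() mainly pulls in names like flan-ul2 that contain no "t5".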

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai", "anthropic>=0.3", "ray"]
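The new protobuf entry supports the sentencepiece tokenizers that T5/flan checkpoints ship with: in some transformers versions, loading them (notably via the slow tokenizer path) raises an ImportError unless protobuf is installed. A minimal check, assuming the model_worker extra is installed; the model name is illustrative:

from transformers import AutoTokenizer

# The slow (sentencepiece-backed) tokenizer is the path that tends to
# require protobuf; without it, from_pretrained fails with an
# ImportError naming the missing library.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", use_fast=False)
print(tokenizer("Hello, world!").input_ids)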
