Skip to content
Closed
2 changes: 1 addition & 1 deletion fastchat/model/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
# such as chatglm, chatglm2
try:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
except NameError:
except (NameError, ValueError):
model = AutoModel.from_config(config, trust_remote_code=True)
linear_weights = get_compressed_list(model)
if os.path.exists(model_path):
Expand Down
8 changes: 8 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class FlanAdapter(T5Adapter):
"""The model adapter for flan-t5-*, flan-ul2"""

def match(self, model_path: str):
return "flan" in model_path.lower()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the instruction template for flan? Do we need to set it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @merrymercy, do you mean prompt template?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangzhen263 Any updates? IIRC, flan-t5 has a default prompt template.


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for koala"""

Expand Down Expand Up @@ -1570,6 +1577,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
]

[project.optional-dependencies]
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
webui = ["gradio"]
train = ["einops", "flash-attn>=2.0", "wandb"]
llm_judge = ["openai", "anthropic>=0.3", "ray"]
Expand Down