Add Falcon 180B chat support: a FALCON_CHAT separator style, a "falcon-chat"
conversation template, and a FalconChatAdapter that claims model paths
containing both "falcon" and "chat" (FalconAdapter now excludes those paths).

diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 9a485b815..76e4f151d 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -27,6 +27,7 @@ class SeparatorStyle(IntEnum):
     RWKV = auto()
     PHOENIX = auto()
     ROBIN = auto()
+    FALCON_CHAT = auto()
 
 
 @dataclasses.dataclass
@@ -200,6 +201,17 @@ def get_prompt(self) -> str:
                 else:
                     ret += role + ":\n"
             return ret
+        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
+            ret = ""
+            if self.system_message:
+                ret += "System: " + self.system_message + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ": "
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -940,6 +952,19 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Falcon 180B chat template
+register_conv_template(
+    Conversation(
+        name="falcon-chat",
+        roles=("User", "Falcon"),
+        messages=[],
+        sep_style=SeparatorStyle.FALCON_CHAT,
+        sep="\n",
+        sep2="<|endoftext|>",
+        stop_str="\nUser:",  # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
+    )
+)
+
 # Phind template
 register_conv_template(
     Conversation(
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 028ac91f1..e6b7bd57e 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1112,7 +1112,7 @@ class FalconAdapter(BaseModelAdapter):
     """The model adapter for tiiuae/falcon-40b"""
 
     def match(self, model_path: str):
-        return "falcon" in model_path.lower()
+        return "falcon" in model_path.lower() and "chat" not in model_path.lower()
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         revision = from_pretrained_kwargs.get("revision", "main")
@@ -1133,6 +1133,14 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("falcon")
 
 
+class FalconChatAdapter(BaseModelAdapter):
+    def match(self, model_path: str):
+        return "falcon" in model_path.lower() and "chat" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("falcon-chat")
+
+
 class TigerBotAdapter(BaseModelAdapter):
     """The model adapter for TigerResearch/tigerbot-7b-sft"""
 
@@ -1647,6 +1655,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(CamelAdapter)
 register_model_adapter(ChangGPTAdapter)
 register_model_adapter(TuluAdapter)
+register_model_adapter(FalconChatAdapter)
 register_model_adapter(FalconAdapter)
 register_model_adapter(TigerBotAdapter)
 register_model_adapter(BaichuanAdapter)