From 8cc8861635c337a661a493bcf4207e8b4e92a4c0 Mon Sep 17 00:00:00 2001 From: lambda Date: Fri, 8 Sep 2023 07:51:03 +0000 Subject: [PATCH 1/4] add falcon 180B chat conversation template --- fastchat/conversation.py | 53 ++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index f733be68a..c11841503 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -27,6 +27,7 @@ class SeparatorStyle(IntEnum): RWKV = auto() PHOENIX = auto() ROBIN = auto() + FALCON_CHAT = auto() @dataclasses.dataclass @@ -200,6 +201,18 @@ def get_prompt(self) -> str: else: ret += role + ":\n" return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = "" + for idx, (role, message) in enumerate(self.messages): + if role == "System": + assert idx == 0, f"System message must be the first message, but got {idx}-th message." + if message: + ret += role + ": " + message + self.sep + else: + assert idx == len(self.messages) - 1, f"Only the last message can be empty, but got {idx}-th message." + ret += role + ": " + + return ret else: raise ValueError(f"Invalid style: {self.sep_style}") @@ -688,33 +701,6 @@ def get_conv_template(name: str) -> Conversation: ) ) -# Falcon default template -register_conv_template( - Conversation( - name="falcon", - roles=("User", "Assistant"), - messages=[], - sep_style=SeparatorStyle.RWKV, - sep="\n", - sep2="<|endoftext|>", - stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text - stop_token_ids=[ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - ], # it better only put special tokens here, because tokenizer only remove special tokens - ) -) - # ChagGPT default template register_conv_template( Conversation( @@ -905,6 +891,19 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Falcon 180B chat template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Falcon", "System"), + messages=[], + sep_style=SeparatorStyle.FALCON_CHAT, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + ) +) + if __name__ == "__main__": print("Vicuna template:") From dc25c9cc4f28904a04c2d1e0005f7f54f6ebf021 Mon Sep 17 00:00:00 2001 From: lambda Date: Fri, 8 Sep 2023 08:04:06 +0000 Subject: [PATCH 2/4] fix format --- fastchat/conversation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index c11841503..1d38a712e 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -205,13 +205,17 @@ def get_prompt(self) -> str: ret = "" for idx, (role, message) in enumerate(self.messages): if role == "System": - assert idx == 0, f"System message must be the first message, but got {idx}-th message." + assert ( + idx == 0 + ), f"System message must be the first message, but got {idx}-th message." if message: ret += role + ": " + message + self.sep else: - assert idx == len(self.messages) - 1, f"Only the last message can be empty, but got {idx}-th message." + assert ( + idx == len(self.messages) - 1 + ), f"Only the last message can be empty, but got {idx}-th message." ret += role + ": " - + return ret else: raise ValueError(f"Invalid style: {self.sep_style}") From e7c9a5fbe5a17334685aa458baacb1aa4398dcd2 Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 10 Sep 2023 14:16:12 +0000 Subject: [PATCH 3/4] restore falcon --- fastchat/conversation.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 1d38a712e..24b064ab1 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -705,6 +705,33 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Falcon default template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Assistant"), + messages=[], + sep_style=SeparatorStyle.RWKV, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_token_ids=[ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # it better only put special tokens here, because tokenizer only remove special tokens + ) +) + # ChagGPT default template register_conv_template( Conversation( @@ -898,7 +925,7 @@ def get_conv_template(name: str) -> Conversation: # Falcon 180B chat template register_conv_template( Conversation( - name="falcon", + name="falcon-chat", roles=("User", "Falcon", "System"), messages=[], sep_style=SeparatorStyle.FALCON_CHAT, From 6aaf38736334521e01e10d1727b882b5c2b04100 Mon Sep 17 00:00:00 2001 From: lambda Date: Wed, 13 Sep 2023 09:42:59 +0000 Subject: [PATCH 4/4] add falcon model adapter to dispatch conv_template ; remove system role --- fastchat/conversation.py | 13 ++++--------- fastchat/model/model_adapter.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 24b064ab1..f49ae9040 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -203,17 +203,12 @@ def get_prompt(self) -> str: return ret elif self.sep_style == SeparatorStyle.FALCON_CHAT: ret = "" - for idx, (role, message) in enumerate(self.messages): - if role == "System": - assert ( - idx == 0 - ), f"System message must be the first message, but got {idx}-th message." + if self.system_message: + ret += "System: " + self.system_message + self.sep + for role, message in self.messages: if message: ret += role + ": " + message + self.sep else: - assert ( - idx == len(self.messages) - 1 - ), f"Only the last message can be empty, but got {idx}-th message." ret += role + ": " return ret @@ -926,7 +921,7 @@ def get_conv_template(name: str) -> Conversation: register_conv_template( Conversation( name="falcon-chat", - roles=("User", "Falcon", "System"), + roles=("User", "Falcon"), messages=[], sep_style=SeparatorStyle.FALCON_CHAT, sep="\n", diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index c1e2b2163..2405f4916 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -1092,7 +1092,7 @@ class FalconAdapter(BaseModelAdapter): """The model adapter for tiiuae/falcon-40b""" def match(self, model_path: str): - return "falcon" in model_path.lower() + return "falcon" in model_path.lower() and "chat" not in model_path.lower() def load_model(self, model_path: str, from_pretrained_kwargs: dict): revision = from_pretrained_kwargs.get("revision", "main") @@ -1113,6 +1113,14 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("falcon") +class FalconChatAdapter(BaseModelAdapter): + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon-chat") + + class TigerBotAdapter(BaseModelAdapter): """The model adapter for TigerResearch/tigerbot-7b-sft""" @@ -1614,6 +1622,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(CamelAdapter) register_model_adapter(ChangGPTAdapter) register_model_adapter(TuluAdapter) +register_model_adapter(FalconChatAdapter) register_model_adapter(FalconAdapter) register_model_adapter(TigerBotAdapter) register_model_adapter(BaichuanAdapter)