From d0daff057d99e7c692a88335494b860ba8670c04 Mon Sep 17 00:00:00 2001
From: Nithin Holla
Date: Tue, 7 Nov 2023 15:49:25 +0100
Subject: [PATCH] Fix Mistral template

---
 fastchat/conversation.py     | 22 +++++++++++++++++-----
 tests/README.md              |  6 ++++++
 tests/test_chat_templates.py | 34 ++++++++++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_chat_templates.py

diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 73d24a72d..47bc7c7f1 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
     PHOENIX = auto()
     ROBIN = auto()
     FALCON_CHAT = auto()
+    MISTRAL = auto()
 
 
 @dataclasses.dataclass
@@ -213,6 +214,17 @@ def get_prompt(self) -> str:
                     ret += role + ": " + message + self.sep
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.MISTRAL:
+            ret = self.sep
+            for i, (role, message) in enumerate(self.messages):
+                if role == "user":
+                    if self.system_message and i == 0:
+                        ret += " [INST] " + system_prompt + " " + message + " [/INST]"
+                    else:
+                        ret += " [INST] " + message + " [/INST]"
+                elif role == "assistant" and message:
+                    ret += message + self.sep2
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
@@ -883,10 +895,10 @@ def get_conv_template(name: str) -> Conversation:
 register_conv_template(
     Conversation(
         name="mistral",
-        system_template="[INST]{system_message}\n",
-        roles=("[INST]", "[/INST]"),
-        sep_style=SeparatorStyle.LLAMA2,
-        sep=" ",
+        system_template="{system_message}",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.MISTRAL,
+        sep="",
         sep2="",
     )
 )
@@ -1129,7 +1141,7 @@ def get_conv_template(name: str) -> Conversation:
 register_conv_template(
     Conversation(
         name="metharme",
         system_template="<|system|>{system_message}",
-        system_message="""Enter RP mode. You shall reply to the user while staying 
+        system_message="""Enter RP mode. You shall reply to the user while staying
 in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.""",
         roles=("<|user|>", "<|model|>"),
diff --git a/tests/README.md b/tests/README.md
index 3d1c1e61c..16bfa2cfd 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -51,3 +51,9 @@ PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
     --model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \
     --model-path SurfaceData/dummy_pythia160m_lora8_peft_chat
 ```
+
+### Test chat templates
+
+```
+python3 test_chat_templates.py
+```
\ No newline at end of file
diff --git a/tests/test_chat_templates.py b/tests/test_chat_templates.py
new file mode 100644
index 000000000..72c9aa41a
--- /dev/null
+++ b/tests/test_chat_templates.py
@@ -0,0 +1,34 @@
+from fastchat.conversation import get_conv_template
+
+
+def get_sample_conversation() -> list[str]:
+    return [
+        "What is your favourite condiment?",
+        "Well, I'm quite partial to a good squeeze of fresh lemon juice. "
" + "It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!", + "Do you have mayonnaise recipes?", + "Here is a recipe for mayonnaise.", + ] + + +def test_chat_template_mistral(): + conversation = get_sample_conversation() + + conv_template = get_conv_template("mistral") + conv_template.append_message(conv_template.roles[0], conversation[0]) + conv_template.append_message(conv_template.roles[1], conversation[1]) + conv_template.append_message(conv_template.roles[0], conversation[2]) + conv_template.append_message(conv_template.roles[1], conversation[3]) + prompt = conv_template.get_prompt() + + expected_prompt = ( + f" [INST] {conversation[0]} [/INST]{conversation[1]} " + f"[INST] {conversation[2]} [/INST]" + f"{conversation[3]} " + ) + + assert prompt == expected_prompt + + +if __name__ == "__main__": + test_chat_template_mistral()