diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 73d24a72d..47bc7c7f1 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
PHOENIX = auto()
ROBIN = auto()
FALCON_CHAT = auto()
+ MISTRAL = auto()
@dataclasses.dataclass
@@ -213,6 +214,17 @@ def get_prompt(self) -> str:
ret += role + ":"
return ret
+ elif self.sep_style == SeparatorStyle.MISTRAL:
+ ret = self.sep
+ for i, (role, message) in enumerate(self.messages):
+ if role == "user":
+ if self.system_message and i == 0:
+ ret += " [INST] " + system_prompt + " " + message + " [/INST]"
+ else:
+ ret += " [INST] " + message + " [/INST]"
+ elif role == "assistant" and message:
+ ret += message + self.sep2 + " "
+ return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
@@ -883,10 +895,10 @@ def get_conv_template(name: str) -> Conversation:
register_conv_template(
Conversation(
name="mistral",
- system_template="[INST]{system_message}\n",
- roles=("[INST]", "[/INST]"),
- sep_style=SeparatorStyle.LLAMA2,
- sep=" ",
+ system_template="{system_message}",
+ roles=("user", "assistant"),
+ sep_style=SeparatorStyle.MISTRAL,
+ sep="",
sep2="",
)
)
@@ -1129,7 +1141,7 @@ def get_conv_template(name: str) -> Conversation:
Conversation(
name="metharme",
system_template="<|system|>{system_message}",
- system_message="""Enter RP mode. You shall reply to the user while staying
+ system_message="""Enter RP mode. You shall reply to the user while staying
in character. Your responses must be detailed, creative, immersive, and drive the scenario
forward.""",
roles=("<|user|>", "<|model|>"),
diff --git a/tests/README.md b/tests/README.md
index 3d1c1e61c..16bfa2cfd 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -51,3 +51,9 @@ PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
--model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \
--model-path SurfaceData/dummy_pythia160m_lora8_peft_chat
```
+
+### Test chat templates
+
+```
+python3 test_chat_templates.py
+```
\ No newline at end of file
diff --git a/tests/test_chat_templates.py b/tests/test_chat_templates.py
new file mode 100644
index 000000000..72c9aa41a
--- /dev/null
+++ b/tests/test_chat_templates.py
@@ -0,0 +1,34 @@
+from fastchat.conversation import get_conv_template
+
+
+def get_sample_conversation() -> list[str]:
+ return [
+ "What is your favourite condiment?",
+ "Well, I'm quite partial to a good squeeze of fresh lemon juice. "
+ "It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!",
+ "Do you have mayonnaise recipes?",
+ "Here is a recipe for mayonnaise.",
+ ]
+
+
+def test_chat_template_mistral():
+ conversation = get_sample_conversation()
+
+ conv_template = get_conv_template("mistral")
+ conv_template.append_message(conv_template.roles[0], conversation[0])
+ conv_template.append_message(conv_template.roles[1], conversation[1])
+ conv_template.append_message(conv_template.roles[0], conversation[2])
+ conv_template.append_message(conv_template.roles[1], conversation[3])
+ prompt = conv_template.get_prompt()
+
+ expected_prompt = (
+ f" [INST] {conversation[0]} [/INST]{conversation[1]} "
+ f"[INST] {conversation[2]} [/INST]"
+ f"{conversation[3]} "
+ )
+
+ assert prompt == expected_prompt
+
+
+if __name__ == "__main__":
+ test_chat_template_mistral()