Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Fix Mistral template
  • Loading branch information
Nithin Holla committed Nov 7, 2023
commit d0daff057d99e7c692a88335494b860ba8670c04
22 changes: 17 additions & 5 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
PHOENIX = auto()
ROBIN = auto()
FALCON_CHAT = auto()
MISTRAL = auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -213,6 +214,17 @@ def get_prompt(self) -> str:
ret += role + ":"

return ret
elif self.sep_style == SeparatorStyle.MISTRAL:
ret = self.sep
for i, (role, message) in enumerate(self.messages):
if role == "user":
if self.system_message and i == 0:
ret += " [INST] " + system_prompt + " " + message + " [/INST]"
else:
ret += " [INST] " + message + " [/INST]"
elif role == "assistant" and message:
ret += message + self.sep2 + " "
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

Expand Down Expand Up @@ -883,10 +895,10 @@ def get_conv_template(name: str) -> Conversation:
register_conv_template(
Conversation(
name="mistral",
system_template="[INST]{system_message}\n",
roles=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
system_template="{system_message}",
roles=("user", "assistant"),
sep_style=SeparatorStyle.MISTRAL,
sep="<s>",
sep2="</s>",
)
)
Expand Down Expand Up @@ -1129,7 +1141,7 @@ def get_conv_template(name: str) -> Conversation:
Conversation(
name="metharme",
system_template="<|system|>{system_message}",
system_message="""Enter RP mode. You shall reply to the user while staying
system_message="""Enter RP mode. You shall reply to the user while staying
in character. Your responses must be detailed, creative, immersive, and drive the scenario
forward.""",
roles=("<|user|>", "<|model|>"),
Expand Down
6 changes: 6 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,9 @@ PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
--model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \
--model-path SurfaceData/dummy_pythia160m_lora8_peft_chat
```

### Test chat templates

```bash
python3 test_chat_templates.py
```
34 changes: 34 additions & 0 deletions tests/test_chat_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastchat.conversation import get_conv_template


def get_sample_conversation() -> list[str]:
    """Return a fixed two-round sample dialogue, alternating user/assistant.

    The list is ordered as [user, assistant, user, assistant] and is used as
    deterministic input for the chat-template tests below.
    """
    first_question = "What is your favourite condiment?"
    first_answer = (
        "Well, I'm quite partial to a good squeeze of fresh lemon juice. "
        "It adds just the right amount of zesty flavour to whatever "
        "I'm cooking up in the kitchen!"
    )
    second_question = "Do you have mayonnaise recipes?"
    second_answer = "Here is a recipe for mayonnaise."
    return [first_question, first_answer, second_question, second_answer]


def test_chat_template_mistral():
    """Check that the "mistral" conversation template renders the expected
    ``<s> [INST] ... [/INST]answer</s>`` prompt for a two-round dialogue."""
    messages = get_sample_conversation()

    template = get_conv_template("mistral")
    # Alternate roles: even indices are user turns, odd indices assistant turns.
    for turn, text in enumerate(messages):
        template.append_message(template.roles[turn % 2], text)

    expected_prompt = (
        "<s> [INST] " + messages[0] + " [/INST]" + messages[1] + "</s> "
        "[INST] " + messages[2] + " [/INST]" + messages[3] + "</s> "
    )

    assert template.get_prompt() == expected_prompt


if __name__ == "__main__":
test_chat_template_mistral()