diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index e7863f03b..79079ba8a 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -7,7 +7,7 @@
import dataclasses
from enum import auto, IntEnum
-from typing import List, Any, Dict, Union
+from typing import List, Any, Dict, Union, Tuple
class SeparatorStyle(IntEnum):
@@ -41,7 +41,7 @@ class Conversation:
# The system message
system_message: str = ""
# The names of two roles
- roles: List[str] = (("USER", "ASSISTANT"),)
+ roles: Tuple[str] = ("USER", "ASSISTANT")
# All messages. Each item is (role, message).
messages: List[List[str]] = ()
# The number of few shot examples
@@ -54,6 +54,8 @@ class Conversation:
stop_str: Union[str, List[str]] = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
+ # Tags to be used in the template
+ tags: Tuple[str] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
@@ -128,13 +130,14 @@ def get_prompt(self) -> str:
else:
ret = "[INST] "
for i, (role, message) in enumerate(self.messages):
+ tag = self.tags[i % 2]
if message:
if i == 0:
ret += message + " "
else:
- ret += role + " " + message + seps[i % 2]
+ ret += tag + " " + message + seps[i % 2]
else:
- ret += role
+ ret += tag
return ret
elif self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
@@ -266,6 +269,7 @@ def copy(self):
sep2=self.sep2,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
+ tags=self.tags,
)
def dict(self):
@@ -846,7 +850,7 @@ def get_conv_template(name: str) -> Conversation:
Conversation(
name="mistral",
system_template="",
- roles=("[INST]", "[/INST]"),
+ tags=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2="",
@@ -860,7 +864,7 @@ def get_conv_template(name: str) -> Conversation:
Conversation(
name="llama-2",
system_template="[INST] <>\n{system_message}\n<>\n\n",
- roles=("[INST]", "[/INST]"),
+ tags=("[INST]", "[/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" ",