diff --git a/fastchat/conversation.py b/fastchat/conversation.py index e7863f03b..79079ba8a 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -7,7 +7,7 @@ import dataclasses from enum import auto, IntEnum -from typing import List, Any, Dict, Union +from typing import List, Any, Dict, Union, Tuple class SeparatorStyle(IntEnum): @@ -41,7 +41,7 @@ class Conversation: # The system message system_message: str = "" # The names of two roles - roles: List[str] = (("USER", "ASSISTANT"),) + roles: Tuple[str] = ("USER", "ASSISTANT") # All messages. Each item is (role, message). messages: List[List[str]] = () # The number of few shot examples @@ -54,6 +54,8 @@ class Conversation: stop_str: Union[str, List[str]] = None # Stops generation if meeting any token in this list stop_token_ids: List[int] = None + # Tags to be used in the template + tags: Tuple[str] = None def get_prompt(self) -> str: """Get the prompt for generation.""" @@ -128,13 +130,14 @@ def get_prompt(self) -> str: else: ret = "[INST] " for i, (role, message) in enumerate(self.messages): + tag = self.tags[i % 2] if message: if i == 0: ret += message + " " else: - ret += role + " " + message + seps[i % 2] + ret += tag + " " + message + seps[i % 2] else: - ret += role + ret += tag return ret elif self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 @@ -266,6 +269,7 @@ def copy(self): sep2=self.sep2, stop_str=self.stop_str, stop_token_ids=self.stop_token_ids, + tags=self.tags, ) def dict(self): @@ -846,7 +850,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="mistral", system_template="", - roles=("[INST]", "[/INST]"), + tags=("[INST]", "[/INST]"), sep_style=SeparatorStyle.LLAMA2, sep=" ", sep2="", @@ -860,7 +864,7 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="llama-2", system_template="[INST] <>\n{system_message}\n<>\n\n", - roles=("[INST]", "[/INST]"), + tags=("[INST]", "[/INST]"), sep_style=SeparatorStyle.LLAMA2, sep=" ", sep2=" ",