diff --git a/docs/model_support.md b/docs/model_support.md
index 9d1aedddc..042e78963 100644
--- a/docs/model_support.md
+++ b/docs/model_support.md
@@ -7,6 +7,8 @@
 - Vicuna, Alpaca, LLaMA, Koala
   - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5`
 - [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
+- [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B)
+- [BAAI/AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B)
 - [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en#using-huggingface-transformers)
 - [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
 - [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index a8bdb1cb6..77aad9844 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -969,13 +969,57 @@ def get_conv_template(name: str) -> Conversation:
         name="aquila-chat",
         system_message="A chat between a curious human and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-        roles=("Human", "Assistant", "System"),
+        roles=("Human", "Assistant"),
         sep_style=SeparatorStyle.ADD_COLON_SINGLE,
         sep="###",
         sep2="",
         stop_str=["###", "</s>", "[UNK]"],
     )
 )
+# AquilaChat2-34B default template
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
+register_conv_template(
+    Conversation(
+        name="aquila-legacy",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+        roles=("### Human: ", "### Assistant: "),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="\n",
+        sep2="</s>",
+        stop_str=["</s>", "[UNK]"],
+    )
+)
+# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
+register_conv_template(
+    Conversation(
+        name="aquila",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+        roles=("Human", "Assistant"),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="###",
+        sep2="</s>",
+        stop_str=["</s>", "[UNK]"],
+    )
+)
+
+# AquilaChat2-7B default template
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
+register_conv_template(
+    Conversation(
+        name="aquila-v1",
+        roles=("<|startofpiece|>", "<|endofpiece|>"),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="",
+        sep2="</s>",
+        stop_str=["</s>", "<|endoftext|>"],
+    )
+)
 
 # Llama2-Chinese default template
 # source: https://huggingface.co/FlagAlpha
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index b26b92491..cc50214d3 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1532,7 +1532,13 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 
 
 class AquilaChatAdapter(BaseModelAdapter):
-    """The model adapter for BAAI/AquilaChat-7B"""
+    """The model adapter for BAAI/Aquila
+
+    Now supports:
+    - BAAI/AquilaChat-7B
+    - BAAI/AquilaChat2-7B
+    - BAAI/AquilaChat2-34B
+    """
 
     def match(self, model_path: str):
         return "aquila" in model_path.lower()
@@ -1552,7 +1558,17 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("aquila-chat")
+        model_path = model_path.lower()
+        # See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347
+        if "aquilachat2" in model_path:
+            if "16k" in model_path:
+                return get_conv_template("aquila")
+            elif "34b" in model_path:
+                return get_conv_template("aquila-legacy")
+            else:
+                return get_conv_template("aquila-v1")
+        else:
+            return get_conv_template("aquila-chat")
 
 
 class Lamma2ChineseAdapter(BaseModelAdapter):
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 3ade406b5..10af25a67 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -352,3 +352,14 @@ def get_model_info(name: str) -> ModelInfo:
         "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca",
         "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)",
     )
+
+register_model_info(
+    [
+        "AquilaChat-7B",
+        "AquilaChat2-7B",
+        "AquilaChat2-34B",
+    ],
+    "Aquila-Chat",
+    "https://huggingface.co/BAAI/AquilaChat2-34B",
+    "Chat models developed by the BAAI team",
+)