Skip to content

Commit ac2a899

Browse files
committed
Add llava 34b template (#3034)
1 parent b21d0f7 commit ac2a899

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

fastchat/conversation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def get_prompt(self) -> str:
172172
ret = "" if system_prompt == "" else system_prompt + self.sep + "\n"
173173
for role, message in self.messages:
174174
if message:
175+
if type(message) is tuple:
176+
message, images = message
177+
message = IMAGE_PLACEHOLDER_STR * len(images) + message
175178
ret += role + "\n" + message + self.sep + "\n"
176179
else:
177180
ret += role + "\n"
@@ -1562,6 +1565,21 @@ def get_conv_template(name: str) -> Conversation:
15621565
)
15631566
)
15641567

1568+
# Llava-chatml
1569+
# reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361
1570+
register_conv_template(
1571+
Conversation(
1572+
name="llava-chatml",
1573+
system_template="<|im_start|>system\n{system_message}",
1574+
system_message="Answer the questions.",
1575+
roles=("<|im_start|>user", "<|im_start|>assistant"),
1576+
sep_style=SeparatorStyle.CHATML,
1577+
sep="<|im_end|>",
1578+
stop_str="<|im_end|>",
1579+
)
1580+
)
1581+
1582+
15651583
if __name__ == "__main__":
15661584
from fastchat.conversation import get_conv_template
15671585

fastchat/model/model_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,10 @@ def match(self, model_path: str):
22172217
return "llava" in model_path.lower()
22182218

22192219
def get_default_conv_template(self, model_path: str) -> Conversation:
2220+
model_path = model_path.lower()
2221+
if "34b" in model_path:
2222+
return get_conv_template("llava-chatml")
2223+
22202224
return get_conv_template("vicuna_v1.1")
22212225

22222226

fastchat/model/model_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,10 @@ def get_model_info(name: str) -> ModelInfo:
645645

646646
register_model_info(
647647
[
648+
"llava-v1.6-34b",
648649
"llava-v1.6-vicuna-13b",
649650
"llava-v1.6-vicuna-7b",
651+
"llava-v1.6-mistral-7b",
650652
"llava-v1.5-13b",
651653
"llava-v1.5-7b",
652654
],

0 commit comments

Comments
 (0)