Skip to content

Commit 09e4357

Browse files
authored
Update qwen and add pygmalion (#2607)
1 parent cbf2853 commit 09e4357

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

fastchat/model/model_adapter.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,20 @@ class QwenChatAdapter(BaseModelAdapter):
14201420
def match(self, model_path: str):
    """Return True when ``model_path`` contains "qwen", ignoring case."""
    lowered_path = model_path.lower()
    return "qwen" in lowered_path
14221422

1423+
def float_set(self, config, option):
    """Enable exactly one floating-point precision flag on ``config``.

    The ``bf16``, ``fp16`` and ``fp32`` attributes of ``config`` are all
    cleared first; the one named by ``option`` is then switched on. If
    ``option`` is none of those three names, every flag stays cleared and a
    hint is printed to stdout.
    """
    precisions = ("bf16", "fp16", "fp32")

    # Reset everything up front so that at most one flag ends up enabled.
    for name in precisions:
        setattr(config, name, False)

    selected = next((name for name in precisions if option == name), None)
    if selected is None:
        print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.")
        return

    setattr(config, selected, True)
1436+
14231437
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
14241438
from transformers.generation import GenerationConfig
14251439

@@ -1430,7 +1444,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
14301444
)
14311445
# NOTE: if you use an old version of the model files, uncomment the line below
14321446
# config.use_flash_attn = False
1433-
config.fp16 = True
1447+
self.float_set(config, "fp16")
14341448
generation_config = GenerationConfig.from_pretrained(
14351449
model_path, trust_remote_code=True
14361450
)
@@ -1698,6 +1712,20 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
16981712
return get_conv_template("lemur-70b-chat")
16991713

17001714

1715+
class PygmalionAdapter(BaseModelAdapter):
    """Model adapter for the Pygmalion/Metharme series of models (e.g., PygmalionAI/mythalion-13b)."""

    # use_fast_tokenizer = False

    def match(self, model_path: str):
        """Return True when the path names a Pygmalion, Mythalion, or Metharme model."""
        lowered_path = model_path.lower()
        hit = re.search(r"pygmalion|mythalion|metharme", lowered_path, re.I)
        return hit is not None

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "metharme" conversation template, whatever ``model_path`` is."""
        return get_conv_template("metharme")
1727+
1728+
17011729
# Note: the registration order matters.
17021730
# The one registered earlier has a higher matching priority.
17031731
register_model_adapter(PeftModelAdapter)
@@ -1760,6 +1788,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
17601788
register_model_adapter(ZephyrAdapter)
17611789
register_model_adapter(XwinLMAdapter)
17621790
register_model_adapter(LemurAdapter)
1791+
register_model_adapter(PygmalionAdapter)
1792+
17631793

17641794
# After all adapters, try the default base adapter.
17651795
register_model_adapter(BaseModelAdapter)

0 commit comments

Comments
 (0)