Skip to content
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e6f8a22
Minor fix on code style
merrymercy Aug 24, 2023
e1f92ef
merge
Trangle Aug 24, 2023
adeaea4
Merge remote-tracking branch 'upstream/main'
Trangle Aug 25, 2023
41f0a74
Merge remote-tracking branch 'upstream/main'
Trangle Aug 25, 2023
b856594
Merge remote-tracking branch 'upstream/main'
Trangle Aug 27, 2023
2714582
Merge remote-tracking branch 'upstream/main'
Trangle Aug 28, 2023
43b8e82
Merge remote-tracking branch 'upstream/main'
Trangle Aug 29, 2023
f0b64ab
Merge remote-tracking branch 'upstream/main'
Trangle Aug 30, 2023
a9ce6c8
Merge remote-tracking branch 'upstream/main'
Trangle Sep 1, 2023
1068a40
Merge remote-tracking branch 'upstream/main'
Trangle Sep 5, 2023
f96e9b2
Merge remote-tracking branch 'upstream/main'
Trangle Sep 6, 2023
d7bdfda
Merge remote-tracking branch 'upstream/main'
Trangle Sep 6, 2023
b0ac061
Merge remote-tracking branch 'upstream/main'
Trangle Sep 7, 2023
3020e1c
Merge remote-tracking branch 'upstream/main'
Trangle Sep 7, 2023
d455b67
Merge remote-tracking branch 'upstream/main'
Trangle Sep 7, 2023
6304a6e
Merge remote-tracking branch 'upstream/main'
Trangle Sep 8, 2023
33e16e1
Merge remote-tracking branch 'upstream/main'
Trangle Sep 10, 2023
a9d0130
Merge remote-tracking branch 'upstream/main'
Trangle Sep 12, 2023
e544989
Merge remote-tracking branch 'upstream/main'
Trangle Sep 18, 2023
4b33e0f
Merge remote-tracking branch 'upstream/main'
Trangle Sep 19, 2023
2959057
Merge remote-tracking branch 'upstream/main'
Trangle Sep 20, 2023
25a4aed
Merge remote-tracking branch 'upstream/main'
Trangle Sep 22, 2023
689c6e3
Merge remote-tracking branch 'upstream/main'
Trangle Oct 7, 2023
0d8736b
Merge remote-tracking branch 'upstream/main'
Trangle Oct 10, 2023
fccd524
Merge remote-tracking branch 'upstream/main'
Trangle Oct 11, 2023
df1474f
Merge remote-tracking branch 'upstream/main'
Trangle Oct 17, 2023
ef61991
Merge remote-tracking branch 'upstream/main'
Trangle Oct 19, 2023
9b512cd
Merge remote-tracking branch 'upstream/main'
Trangle Oct 25, 2023
f339331
ci
Trangle Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,20 @@ class QwenChatAdapter(BaseModelAdapter):
def match(self, model_path: str):
    """Report whether the lower-cased ``model_path`` contains ``"qwen"``."""
    lowered_path = model_path.lower()
    return "qwen" in lowered_path

def float_set(self, config, option):
    """Select exactly one floating-point precision flag on ``config``.

    The ``bf16``, ``fp16`` and ``fp32`` attributes of ``config`` are all
    reset to ``False`` first; the attribute named by ``option`` is then set
    to ``True``.

    Args:
        config: Object whose ``bf16``/``fp16``/``fp32`` attributes are written.
        option: One of ``"bf16"``, ``"fp16"`` or ``"fp32"``.

    Any other ``option`` only prints a message, leaving all three flags
    ``False``.
    """
    precision_flags = ("bf16", "fp16", "fp32")

    # Clear every flag before enabling the requested one.
    for flag_name in precision_flags:
        setattr(config, flag_name, False)

    if option in precision_flags:
        setattr(config, option, True)
        return

    print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.")

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
from transformers.generation import GenerationConfig

Expand All @@ -1430,7 +1444,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
)
# NOTE: if you use an old version of the model files, uncomment the line below
# config.use_flash_attn = False
config.fp16 = True
self.float_set(config, "fp16")
generation_config = GenerationConfig.from_pretrained(
model_path, trust_remote_code=True
)
Expand Down Expand Up @@ -1698,6 +1712,20 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("lemur-70b-chat")


class PygmalionAdapter(BaseModelAdapter):
    """Model adapter for the Pygmalion/Metharme family of models (e.g., PygmalionAI/mythalion-13b)."""

    # use_fast_tokenizer = False

    def match(self, model_path: str):
        """Return True when the path mentions pygmalion, mythalion or metharme."""
        found = re.search(
            r"pygmalion|mythalion|metharme", model_path.lower(), re.I
        )
        # re.search yields a match object or None; report only whether it hit.
        return found is not None

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "metharme" conversation template for any matched path."""
        template = get_conv_template("metharme")
        return template


# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
Expand Down Expand Up @@ -1760,6 +1788,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(ZephyrAdapter)
register_model_adapter(XwinLMAdapter)
register_model_adapter(LemurAdapter)
register_model_adapter(PygmalionAdapter)


# Register the default base adapter last: adapters registered earlier have a
# higher matching priority, so it is only tried after all adapters above.
register_model_adapter(BaseModelAdapter)