Skip to content
Prev Previous commit
Next Next commit
Revert "fix code style"
This reverts commit 9ce4681.
  • Loading branch information
wangzhen263 committed Sep 13, 2023
commit 5881a9543758bf3bc10839d9cec78e1b86532496
53 changes: 38 additions & 15 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
from fastchat.model.llama_condense_monkey_patch import (
replace_llama_with_condense,
)
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
Expand Down Expand Up @@ -212,7 +214,7 @@ def load_model(

if "max_memory" in kwargs:
kwargs["max_memory"]["cpu"] = (
str(math.floor(psutil.virtual_memory().available / 2 ** 20)) + "Mib"
str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
)
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit_fp32_cpu_offload=cpu_offloading
Expand Down Expand Up @@ -531,7 +533,9 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs,
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
self.raise_warning_for_old_weights(model)
return model, tokenizer
Expand Down Expand Up @@ -601,7 +605,9 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs,
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer

Expand All @@ -613,9 +619,7 @@ class GoogleFlanAdapter(BaseModelAdapter):
"""The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""

def match(self, model_path: str):
    """Return True if *model_path* names a Google Flan-family model.

    Recognizes Flan-derived checkpoints such as flan-t5-*, flan-ul2,
    lmsys/fastchat-t5-3b-v1.0 and Salesforce/codet5p-6b.

    Args:
        model_path: Hugging Face repo id or local path of the model.

    Returns:
        bool: whether any known Flan-family keyword occurs in the path.
    """
    # Fix inverted containment: the original tested `model_path in model_str`,
    # which only succeeds if the entire path fits inside a short keyword like
    # "flan-" — effectively never. We must look for the keyword IN the path.
    return any(
        keyword in model_path for keyword in ("flan-", "fastchat-t5", "codet5p")
    )

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
Expand Down Expand Up @@ -686,7 +690,9 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs,
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
# 50277 means "### End"
tokenizer.eos_token_id = 50277
Expand Down Expand Up @@ -935,7 +941,9 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs,
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
)
return model, tokenizer

Expand Down Expand Up @@ -1103,7 +1111,9 @@ def match(self, model_path: str):
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision,
model_path,
trust_remote_code=True,
revision=revision,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
Expand Down Expand Up @@ -1267,7 +1277,9 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model_path, use_fast=self.use_fast_tokenizer, revision=revision
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs,
model_path,
low_cpu_mem_usage=True,
**from_pretrained_kwargs,
).eval()
return model, tokenizer

Expand Down Expand Up @@ -1314,7 +1326,10 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
from transformers.generation import GenerationConfig

revision = from_pretrained_kwargs.get("revision", "main")
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True,)
config = AutoConfig.from_pretrained(
model_path,
trust_remote_code=True,
)
# NOTE: if you use the old version of model file, please remove the comments below
# config.use_flash_attn = False
config.fp16 = True
Expand Down Expand Up @@ -1356,7 +1371,10 @@ def match(self, model_path: str):

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModel.from_pretrained(model_path, **from_pretrained_kwargs,)
model = AutoModel.from_pretrained(
model_path,
**from_pretrained_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
Expand All @@ -1382,7 +1400,10 @@ def match(self, model_path: str):

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
model = AutoModel.from_pretrained(model_path, **from_pretrained_kwargs,)
model = AutoModel.from_pretrained(
model_path,
**from_pretrained_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision
)
Expand Down Expand Up @@ -1431,7 +1452,9 @@ def match(self, model_path: str):
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, revision=revision,
model_path,
trust_remote_code=True,
revision=revision,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
Expand Down