Skip to content
33 changes: 4 additions & 29 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
AutoTokenizer,
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
)

from fastchat.constants import CPU_ISA
Expand Down Expand Up @@ -616,11 +615,11 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")


class CodeT5pAdapter(BaseModelAdapter):
"""The model adapter for Salesforce/codet5p-6b"""
class GoogleFlanAdapter(BaseModelAdapter):
"""The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""

def match(self, model_path: str):
return "codet5p" in model_path.lower()
return any(model_path in model_str for model_str in ["flan", "t5", "codet5p"])

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
Expand All @@ -634,28 +633,6 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class T5Adapter(BaseModelAdapter):
"""The model adapter for lmsys/fastchat-t5-3b-v1.0"""

def match(self, model_path: str):
return "t5" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
return model, tokenizer


class FlanAdapter(T5Adapter):
"""The model adapter for flan-t5-*, flan-ul2"""

def match(self, model_path: str):
return "flan" in model_path.lower()


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for koala"""

Expand Down Expand Up @@ -1597,9 +1574,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(GoogleFlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
Expand Down