diff --git a/fastchat/llm_judge/common.py b/fastchat/llm_judge/common.py
index e36906630..57fa70048 100644
--- a/fastchat/llm_judge/common.py
+++ b/fastchat/llm_judge/common.py
@@ -14,7 +14,11 @@
 import openai
 import anthropic
 
-from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST
+from fastchat.model.model_adapter import (
+    get_conversation_template,
+    ANTHROPIC_MODEL_LIST,
+    OPENAI_MODEL_LIST,
+)
 
 # API setting constants
 API_MAX_RETRY = 16
@@ -159,7 +163,7 @@ def run_judge_single(question, answer, judge, ref_answer, multi_turn=False):
     conv.append_message(conv.roles[0], user_prompt)
     conv.append_message(conv.roles[1], None)
 
-    if model in ["gpt-3.5-turbo", "gpt-4"]:
+    if model in OPENAI_MODEL_LIST:
         judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048)
     elif model in ANTHROPIC_MODEL_LIST:
         judgment = chat_completion_anthropic(
@@ -262,7 +266,7 @@ def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=F
     conv.append_message(conv.roles[0], user_prompt)
     conv.append_message(conv.roles[1], None)
 
-    if model in ["gpt-3.5-turbo", "gpt-4"]:
+    if model in OPENAI_MODEL_LIST:
         conv.set_system_message(system_prompt)
         judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048)
     elif model in ANTHROPIC_MODEL_LIST:
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index bb268e093..bd9eeb9fe 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -58,6 +58,17 @@
     "claude-instant-1.2",
 )
 
+OPENAI_MODEL_LIST = (
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0301",
+    "gpt-3.5-turbo-0613",
+    "gpt-3.5-turbo-1106",
+    "gpt-4",
+    "gpt-4-0314",
+    "gpt-4-0613",
+    "gpt-4-turbo",
+)
+
 
 class BaseModelAdapter:
     """The base and the default model adapter."""
@@ -1053,16 +1064,7 @@ class ChatGPTAdapter(BaseModelAdapter):
     """The model adapter for ChatGPT"""
 
     def match(self, model_path: str):
-        return model_path in (
-            "gpt-3.5-turbo",
-            "gpt-3.5-turbo-0301",
-            "gpt-3.5-turbo-0613",
-            "gpt-3.5-turbo-1106",
-            "gpt-4",
-            "gpt-4-0314",
-            "gpt-4-0613",
-            "gpt-4-turbo",
-        )
+        return model_path in OPENAI_MODEL_LIST
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         raise NotImplementedError()