diff --git a/fastchat/llm_judge/common.py b/fastchat/llm_judge/common.py
index 18255d711..4b598cefb 100644
--- a/fastchat/llm_judge/common.py
+++ b/fastchat/llm_judge/common.py
@@ -400,7 +400,10 @@ def play_a_match_pair(match: MatchPair, output_file: str):
     return result
 
 
-def chat_compeletion_openai(model, conv, temperature, max_tokens):
+def chat_compeletion_openai(model, conv, temperature, max_tokens, api_dict=None):
+    if api_dict is not None:
+        openai.api_base = api_dict["api_base"]
+        openai.api_key = api_dict["api_key"]
     output = API_ERROR_OUTPUT
     for _ in range(API_MAX_RETRY):
         try:
@@ -421,11 +424,15 @@ def chat_compeletion_openai(model, conv, temperature, max_tokens):
     return output
 
 
-def chat_compeletion_openai_azure(model, conv, temperature, max_tokens):
+def chat_compeletion_openai_azure(model, conv, temperature, max_tokens, api_dict=None):
     openai.api_type = "azure"
-    openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"]
-    openai.api_key = os.environ["AZURE_OPENAI_KEY"]
-    openai.api_version = "2023-05-15"
+    openai.api_version = "2023-07-01-preview"
+    if api_dict is not None:
+        openai.api_base = api_dict["api_base"]
+        openai.api_key = api_dict["api_key"]
+    else:
+        openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"]
+        openai.api_key = os.environ["AZURE_OPENAI_KEY"]
 
     if "azure-" in model:
         model = model[6:]
@@ -446,6 +453,12 @@ def chat_compeletion_openai_azure(model, conv, temperature, max_tokens):
         except openai.error.OpenAIError as e:
             print(type(e), e)
             time.sleep(API_RETRY_SLEEP)
+        except openai.error.InvalidRequestError as e:
+            print(type(e), e)
+            break
+        except KeyError:
+            print(response)
+            break
 
     return output
 
diff --git a/fastchat/llm_judge/gen_api_answer.py b/fastchat/llm_judge/gen_api_answer.py
index 53d6d18ba..be099c44d 100644
--- a/fastchat/llm_judge/gen_api_answer.py
+++ b/fastchat/llm_judge/gen_api_answer.py
@@ -27,7 +27,7 @@ def get_answer(
     question: dict, model: str, num_choices: int, max_tokens: int, answer_file: str
 ):
-    if args.force_temperature:
+    if args.force_temperature is not None:
         temperature = args.force_temperature
     elif question["category"] in temperature_config:
         temperature = temperature_config[question["category"]]
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index dfdb18e45..cf96c7933 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -944,6 +944,19 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("chatgpt")
 
 
+class AzureOpenAIAdapter(BaseModelAdapter):
+    """The model adapter for Azure OpenAI"""
+
+    def match(self, model_path: str):
+        return model_path in ("azure-gpt-35-turbo", "azure-gpt-4")
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        raise NotImplementedError()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("chatgpt")
+
+
 class ClaudeAdapter(BaseModelAdapter):
     """The model adapter for Claude"""
 
@@ -1719,6 +1732,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(BardAdapter)
 register_model_adapter(PaLM2Adapter)
 register_model_adapter(ChatGPTAdapter)
+register_model_adapter(AzureOpenAIAdapter)
 register_model_adapter(ClaudeAdapter)
 register_model_adapter(MPTAdapter)
 register_model_adapter(BiLLaAdapter)