def ai2_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a completion from AI2's InferD inference endpoint.

    Posts ``messages`` to InferD and reads the response line by line. Every
    non-empty line is parsed as JSON and its output text is appended to a
    running buffer, which is yielded after each line.

    Args:
        model_name: Only recorded in the request log. The model that is
            actually queried is the hard-coded InferD ``model_id`` below.
        messages: Sent unchanged as ``input.messages``. The payload layout is
            specific to the Tulu2 model.
        temperature: Sampling temperature, sent as ``opts.temperature``.
        top_p: Nucleus sampling value, sent as ``opts.top_p``. Values below
            1.0 are rejected when ``temperature`` is 0.0.
        max_new_tokens: Sent as ``opts.max_tokens``.
        api_key: Bearer token. Falls back to the ``AI2_API_KEY`` environment
            variable when not given.
        api_base: Endpoint URL. Falls back to the InferD URL when not given.

    Yields:
        dict: ``{"text": <all text received so far>, "error_code": 0}``, once
        per non-empty response line.

    Raises:
        ValueError: If ``temperature`` is 0.0 while ``top_p`` is below 1.0, if
            the HTTP status is not 200, or if a response line carries no
            ``result.output``.
    """
    # Validate first, before logging or any I/O, so the request log below only
    # ever describes a request that is really sent.
    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    from json import loads

    from requests import post

    # get keys and needed values
    ai2_key = api_key or os.environ.get("AI2_API_KEY")
    api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
    model_id = "mod_01hhgcga70c91402r9ssyxekan"

    # Used for logging only; the payload actually sent is built below.
    gen_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # The response is used as a context manager so the streamed connection is
    # released on every exit path: normal completion, either error below, or
    # the consumer abandoning the generator early.
    with post(
        api_base,
        stream=True,
        headers={"Authorization": f"Bearer {ai2_key}"},
        json={
            "model_id": model_id,
            # This input format is specific to the Tulu2 model. Other models
            # may require different input formats. See the model's schema
            # documentation on InferD for more information.
            "input": {
                "messages": messages,
                "opts": {
                    "max_tokens": max_new_tokens,
                    "temperature": temperature,
                    "top_p": top_p,
                    "logprobs": 1,  # increase for more choices
                },
            },
        },
    ) as res:
        if res.status_code != 200:
            logger.error(f"unexpected response ({res.status_code}): {res.text}")
            raise ValueError("unexpected response from InferD", res)

        text = ""
        for line in res.iter_lines():
            if not line:
                # Blank lines carry no payload.
                continue

            part = loads(line)
            if "result" not in part or "output" not in part["result"]:
                logger.error(f"unexpected part: {part}")
                raise ValueError("empty result in InferD response")

            text += "".join(part["result"]["output"]["text"])
            yield {
                "text": text,
                "error_code": 0,
            }