add ai2 api
infwinston committed Jan 6, 2024
commit 5de0697a22115a53f34d11c6fd4955e484cd422b
fastchat/serve/api_provider.py (3 additions, 4 deletions)
@@ -224,19 +224,17 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
 
 def ai2_api_stream_iter(
     model_name,
+    model_id,
     messages,
     temperature,
     top_p,
     max_new_tokens,
     api_key=None,
     api_base=None,
 ):
-    from requests import post
-
     # get keys and needed values
     ai2_key = api_key or os.environ.get("AI2_API_KEY")
     api_base = api_base or "https://inferd.allen.ai/api/v1/infer"
-    model_id = "mod_01hhgcga70c91402r9ssyxekan"
 
     # Make requests
     gen_params = {
@@ -253,7 +251,7 @@ def ai2_api_stream_iter(
     if temperature == 0.0 and top_p < 1.0:
         raise ValueError("top_p must be 1 when temperature is 0.0")
 
-    res = post(
+    res = requests.post(
         api_base,
         stream=True,
         headers={"Authorization": f"Bearer {ai2_key}"},
@@ -272,6 +270,7 @@
                 },
             },
         },
+        timeout=5,
     )
 
     if res.status_code != 200:
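
For orientation, a hedged usage sketch (not part of the commit) of how the updated helper could be called directly after this change. The display name, the InferD model id, and the shape of the yielded chunks are assumptions:

from fastchat.serve.api_provider import ai2_api_stream_iter

messages = [{"role": "user", "content": "Say hello."}]

stream = ai2_api_stream_iter(
    model_name="my-ai2-model",                  # display name (placeholder)
    model_id="mod_xxxxxxxxxxxxxxxxxxxxxxxxxx",  # InferD model id (placeholder)
    messages=messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=256,
    api_key=None,   # falls back to the AI2_API_KEY environment variable
    api_base=None,  # falls back to https://inferd.allen.ai/api/v1/infer
)

# FastChat stream iterators conventionally yield dicts with a "text" field;
# the exact chunk schema of this helper is assumed here.
for chunk in stream:
    print(chunk.get("text", ""), end="", flush=True)

Passing model_id as a parameter, instead of the previously hard-coded InferD id, is what lets the web server route different registered AI2 models to their own endpoints.
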
fastchat/serve/gradio_web_server.py (13 additions, 0 deletions)
@@ -39,6 +39,7 @@
     gemini_api_stream_iter,
     mistral_api_stream_iter,
     nvidia_api_stream_iter,
+    ai2_api_stream_iter,
     init_palm_chat,
 )
 from fastchat.utils import (
@@ -457,6 +458,18 @@ def bot_response(
             max_new_tokens,
             model_api_dict["api_base"],
         )
+    elif model_api_dict["api_type"] == "ai2":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = ai2_api_stream_iter(
+            model_name,
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError
 
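
The new branch in bot_response reads api_type, api_base, and api_key from the registered endpoint entry and forwards model_api_dict["model_name"] as the model_id argument, so that field should carry the InferD model id rather than the display name. A hedged sketch of a matching registration entry, written as a Python dict; the variable name and all values are illustrative:

# Hypothetical endpoint registration entry consumed by the new "ai2" branch.
# The outer key is the display name shown in the UI; "model_name" holds the
# InferD model id because bot_response forwards it as model_id.
api_endpoint_info = {
    "my-ai2-model": {
        "api_type": "ai2",
        "model_name": "mod_xxxxxxxxxxxxxxxxxxxxxxxxxx",
        "api_base": "https://inferd.allen.ai/api/v1/infer",
        "api_key": "<AI2 API key>",
    },
}
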