add nvidia steerlm & gemini streaming fix
infwinston committed Jan 4, 2024
commit e9ef1c6befe31bdad64a76a5809fd8ccf34f95de
11 changes: 11 additions & 0 deletions fastchat/conversation.py
@@ -1403,6 +1403,17 @@ def get_conv_template(name: str) -> Conversation:
    )
)

# nvidia/Llama2-70B-SteerLM-Chat
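# Separator fields are left as None because this model is served through
# NVIDIA's API; the prompt is sent as OpenAI-style messages (see
# conv.to_openai_api_messages() in gradio_web_server.py) rather than rendered
# with separator strings.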
register_conv_template(
    Conversation(
        name="steerlm",
        system_message="",
        roles=("user", "assistant"),
        sep_style=None,
        sep=None,
    )
)

if __name__ == "__main__":
    from fastchat.conversation import get_conv_template

11 changes: 11 additions & 0 deletions fastchat/model/model_adapter.py
@@ -2066,6 +2066,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("solar")


class SteerLMAdapter(BaseModelAdapter):
    """The model adapter for nvidia/Llama2-70B-SteerLM-Chat"""

    def match(self, model_path: str):
        return "steerlm-chat" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("steerlm")
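# The adapter above only supplies the default conversation template; for this
# model, generation goes through the NVIDIA API path added to api_provider.py
# and gradio_web_server.py in this commit.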


# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
@@ -2147,6 +2157,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(MetaMathAdapter)
register_model_adapter(BagelAdapter)
register_model_adapter(SolarAdapter)
register_model_adapter(SteerLMAdapter)

# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
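
Illustrative only, not part of the commit: a quick check of how the new adapter resolves, assuming FastChat's existing get_model_adapter helper and the Hugging Face repo id as the model path.

    from fastchat.model.model_adapter import get_model_adapter

    adapter = get_model_adapter("nvidia/Llama2-70B-SteerLM-Chat")
    conv = adapter.get_default_conv_template("nvidia/Llama2-70B-SteerLM-Chat")
    print(type(adapter).__name__, conv.name)  # expected: SteerLMAdapter steerlm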
7 changes: 7 additions & 0 deletions fastchat/model/model_registry.py
@@ -91,6 +91,13 @@ def get_model_info(name: str) -> ModelInfo:
    "Claude Instant by Anthropic",
)

register_model_info(
["llama2-70b-steerlm-chat"],
"Llama2-70B-SteerLM-Chat",
"https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat",
"A Llama fine-tuned with SteerLM method by NVIDIA",
)
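# The "llama2-70b-steerlm-chat" key here matches the name asserted in
# nvidia_api_stream_iter (api_provider.py) and the entries added to the arena
# sampling tables in gradio_block_arena_anony.py.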

register_model_info(
["pplx-70b-online", "pplx-7b-online"],
"pplx-online-llms",
68 changes: 61 additions & 7 deletions fastchat/serve/api_provider.py
@@ -2,7 +2,10 @@

from json import loads
import os

import json
import random
import requests
import time

from fastchat.utils import build_logger
@@ -177,6 +180,12 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": conv,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
@@ -195,14 +204,22 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
    convo = model.start_chat(history=history)
    response = convo.send_message(conv.messages[-2][1], stream=True)

    text = ""
    for chunk in response:
        text += chunk.text
        data = {
            "text": text,
            "error_code": 0,
    try:
        text = ""
        for chunk in response:
            text += chunk.text
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        reason = chunk.candidates
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
        yield data
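    # Note: like the other *_api_stream_iter helpers in this file, each yielded
    # dict carries the full text generated so far rather than a delta.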


def ai2_api_stream_iter(
@@ -319,3 +336,40 @@ def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_to
            "error_code": 0,
            }
            yield data


def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
    assert model_name in ["llama2-70b-steerlm-chat"]

    api_key = os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.0001

    payload = {
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    response = requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=1
    )
    text = ""
    for line in response.iter_lines():
        if line:
            # Each server-sent event line looks like `data: {...}`; the stream
            # ends with a `data: [DONE]` sentinel.
            data = line.decode("utf-8")
            if data.endswith("[DONE]"):
                break
            # Strip the leading "data: " prefix, pull the incremental token text
            # out of the OpenAI-style chunk, and yield the running total.
            data = json.loads(data[6:])["choices"][0]["delta"]["content"]
            text += data
            yield {"text": text, "error_code": 0}
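
Illustrative only, not part of the commit: a minimal sketch of how the helper above could be driven directly, assuming NVIDIA_API_KEY is set in the environment. The endpoint URL and sampling values below are placeholders; in the arena the real api_base comes from the endpoint registration read by gradio_web_server.py.

    # Hypothetical values for a standalone run of nvidia_api_stream_iter.
    example_api_base = "https://api.example.com/v1/chat/completions"  # placeholder URL
    example_messages = [{"role": "user", "content": "Say hello in one sentence."}]
    for chunk in nvidia_api_stream_iter(
        "llama2-70b-steerlm-chat",
        example_messages,
        temp=0.7,
        top_p=1.0,
        max_tokens=128,
        api_base=example_api_base,
    ):
        print(chunk["text"])  # cumulative text generated so far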
31 changes: 19 additions & 12 deletions fastchat/serve/gradio_block_arena_anony.py
@@ -171,11 +171,13 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
    "claude-1": 2,
    "claude-instant-1": 4,
    "gemini-pro": 4,
    "gemini-pro-dev-api": 4,
    "pplx-7b-online": 4,
    "pplx-70b-online": 4,
    "solar-10.7b-instruct-v1.0": 2,
    "llama2-70b-steerlm-chat": 2,
    "mixtral-8x7b-instruct-v0.1": 4,
    "mistral-medium": 8,
    "mistral-medium": 4,
    "openhermes-2.5-mistral-7b": 2,
    "dolphin-2.2.1-mistral-7b": 2,
    "wizardlm-70b": 2,
@@ -186,7 +188,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
    "openchat-3.5": 2,
    "chatglm3-6b": 2,
    # tier 1
    "deluxe-chat-v1.2": 2,
    "deluxe-chat-v1.2": 4,
    "llama-2-70b-chat": 1.5,
    "llama-2-13b-chat": 1.5,
    "codellama-34b-instruct": 1.5,
@@ -254,6 +256,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
    "claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"},
    "claude-instant-1": {"gpt-3.5-turbo-1106", "claude-2.1"},
    "gemini-pro": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
    "gemini-pro-dev-api": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
    "deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-turbo"},
    "deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-turbo"},
    "pplx-7b-online": {"gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "llama-2-70b-chat"},
@@ -297,19 +300,14 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
}

SAMPLING_BOOST_MODELS = [
    # "tulu-2-dpo-70b",
    # "yi-34b-chat",
    "claude-2.1",
    # "claude-1",
    # "claude-2.1",
    "gpt-4-0613",
    # "gpt-3.5-turbo-1106",
    # "gpt-4-0314",
    "gpt-4-turbo",
    # "dolphin-2.2.1-mistral-7b",
    # "mixtral-8x7b-instruct-v0.1",
    "mistral-medium",
    "llama2-70b-steerlm-chat",
    "gemini-pro-dev-api",
    # "gemini-pro",
    # "solar-10.7b-instruct-v1.0",
]
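# Entries in SAMPLING_BOOST_MODELS are sampled more often in anonymous battles
# than their base sampling weights alone would give them.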

# outage models won't be sampled.
@@ -479,13 +477,22 @@ def bot_response_multi(
            )
        )

    is_gemini = []
    for i in range(num_sides):
        is_gemini.append("gemini" in states[i].model_name)

    chatbots = [None] * num_sides
    iters = 0
    while True:
        stop = True
        iters += 1
        for i in range(num_sides):
            try:
                ret = next(gen[i])
                states[i], chatbots[i] = ret[0], ret[1]
                # yield gemini fewer times as its chunk size is larger;
                # otherwise, gemini will stream too fast (advance it only on
                # iterations 1-2 and then roughly every 30th pass)
                if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
                    ret = next(gen[i])
                    states[i], chatbots[i] = ret[0], ret[1]
                stop = False
            except StopIteration:
                pass
13 changes: 11 additions & 2 deletions fastchat/serve/gradio_block_arena_named.py
@@ -242,13 +242,22 @@ def bot_response_multi(
            )
        )

    is_gemini = []
    for i in range(num_sides):
        is_gemini.append("gemini" in states[i].model_name)

    chatbots = [None] * num_sides
    iters = 0
    while True:
        stop = True
        iters += 1
        for i in range(num_sides):
            try:
                ret = next(gen[i])
                states[i], chatbots[i] = ret[0], ret[1]
                # yield gemini fewer times as its chunk size is larger;
                # otherwise, gemini will stream too fast (advance it only on
                # iterations 1-2 and then roughly every 30th pass)
                if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
                    ret = next(gen[i])
                    states[i], chatbots[i] = ret[0], ret[1]
                stop = False
            except StopIteration:
                pass
11 changes: 11 additions & 0 deletions fastchat/serve/gradio_web_server.py
@@ -38,6 +38,7 @@
    palm_api_stream_iter,
    gemini_api_stream_iter,
    mistral_api_stream_iter,
    nvidia_api_stream_iter,
    init_palm_chat,
)
from fastchat.utils import (
@@ -446,6 +447,16 @@ def bot_response(
            stream_iter = mistral_api_stream_iter(
                model_name, prompt, temperature, top_p, max_new_tokens
            )
        elif model_api_dict["api_type"] == "nvidia":
            prompt = conv.to_openai_api_messages()
            stream_iter = nvidia_api_stream_iter(
                model_name,
                prompt,
                temperature,
                top_p,
                max_new_tokens,
                model_api_dict["api_base"],
            )
        else:
            raise NotImplementedError
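
Illustrative only, not part of the commit: a rough sketch of the endpoint entry the new "nvidia" branch above expects. Only the api_type and api_base keys are read by this code; the key layout and the URL below are assumptions, and the API key itself comes from the NVIDIA_API_KEY environment variable inside nvidia_api_stream_iter rather than from the registration entry.

    # Hypothetical registration entry; only "api_type" and "api_base" are read above.
    api_endpoint_entry = {
        "llama2-70b-steerlm-chat": {
            "api_type": "nvidia",
            "api_base": "https://api.example.com/v1/chat/completions",  # placeholder URL
        }
    }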
