add nvidia steerlm & gemini streaming fix
infwinston committed Jan 4, 2024
commit e9ef1c6befe31bdad64a76a5809fd8ccf34f95de
11 changes: 11 additions & 0 deletions fastchat/conversation.py
@@ -1403,6 +1403,17 @@ def get_conv_template(name: str) -> Conversation:
    )
)

# nvidia/Llama2-70B-SteerLM-Chat
register_conv_template(
    Conversation(
        name="steerlm",
        system_message="",
        roles=("user", "assistant"),
        sep_style=None,
        sep=None,
    )
)

if __name__ == "__main__":
    from fastchat.conversation import get_conv_template

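A minimal usage sketch (illustrative, not part of this commit) of how the new "steerlm" template would be fetched and turned into the message list that NVIDIA's API consumes; sep_style and sep stay None because the model is reached through an OpenAI-style messages API rather than a single formatted prompt string:

    from fastchat.conversation import get_conv_template

    conv = get_conv_template("steerlm")
    conv.append_message(conv.roles[0], "What does SteerLM do?")  # user turn
    conv.append_message(conv.roles[1], None)  # placeholder for the model reply
    messages = conv.to_openai_api_messages()  # OpenAI-style message list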
11 changes: 11 additions & 0 deletions fastchat/model/model_adapter.py
@@ -2066,6 +2066,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("solar")


class SteerLMAdapter(BaseModelAdapter):
    """The model adapter for nvidia/Llama2-70B-SteerLM-Chat"""

    def match(self, model_path: str):
        return "steerlm-chat" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("steerlm")


# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
@@ -2147,6 +2157,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(MetaMathAdapter)
register_model_adapter(BagelAdapter)
register_model_adapter(SolarAdapter)
register_model_adapter(SteerLMAdapter)

# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
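For context, a simplified sketch of the lookup that the registration order feeds into (illustrative only; FastChat's actual get_model_adapter also normalizes paths and caches results):

    def match_model_adapter(model_path: str):
        # First registered match wins, so SteerLMAdapter is consulted before
        # the catch-all BaseModelAdapter registered last.
        for adapter in model_adapters:
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"No valid model adapter for {model_path}")

    # "nvidia/Llama2-70B-SteerLM-Chat".lower() contains "steerlm-chat", so the
    # new adapter matches and returns the "steerlm" conversation template.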
7 changes: 7 additions & 0 deletions fastchat/model/model_registry.py
@@ -91,6 +91,13 @@ def get_model_info(name: str) -> ModelInfo:
"Claude Instant by Anthropic",
)

register_model_info(
["llama2-70b-steerlm-chat"],
"Llama2-70B-SteerLM-Chat",
"https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat",
"A Llama fine-tuned with SteerLM method by NVIDIA",
)

register_model_info(
["pplx-70b-online", "pplx-7b-online"],
"pplx-online-llms",
68 changes: 61 additions & 7 deletions fastchat/serve/api_provider.py
@@ -2,7 +2,10 @@

from json import loads
import os

import json
import random
import requests
import time

from fastchat.utils import build_logger
Expand Down Expand Up @@ -177,6 +180,12 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
"max_output_tokens": max_new_tokens,
"top_p": top_p,
}
params = {
"model": model_name,
"prompt": conv,
}
params.update(generation_config)
logger.info(f"==== request ====\n{params}")

safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
@@ -195,14 +204,22 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
    convo = model.start_chat(history=history)
    response = convo.send_message(conv.messages[-2][1], stream=True)

    text = ""
    for chunk in response:
        text += chunk.text
        data = {
            "text": text,
            "error_code": 0,
    try:
        text = ""
        for chunk in response:
            text += chunk.text
            data = {
                "text": text,
                "error_code": 0,
            }
            yield data
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        reason = chunk.candidates
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
        yield data


def ai2_api_stream_iter(
@@ -319,3 +336,40 @@ def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_to
"error_code": 0,
}
yield data


def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
    assert model_name in ["llama2-70b-steerlm-chat"]

    api_key = os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.0001

    payload = {
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    response = requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=1
    )
    text = ""
    for line in response.iter_lines():
        if line:
            data = line.decode("utf-8")
            if data.endswith("[DONE]"):
                break
            data = json.loads(data[6:])["choices"][0]["delta"]["content"]
            text += data
            yield {"text": text, "error_code": 0}
31 changes: 19 additions & 12 deletions fastchat/serve/gradio_block_arena_anony.py
@@ -171,11 +171,13 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"claude-1": 2,
"claude-instant-1": 4,
"gemini-pro": 4,
"gemini-pro-dev-api": 4,
"pplx-7b-online": 4,
"pplx-70b-online": 4,
"solar-10.7b-instruct-v1.0": 2,
"llama2-70b-steerlm-chat": 2,
"mixtral-8x7b-instruct-v0.1": 4,
"mistral-medium": 8,
"mistral-medium": 4,
"openhermes-2.5-mistral-7b": 2,
"dolphin-2.2.1-mistral-7b": 2,
"wizardlm-70b": 2,
@@ -186,7 +188,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"openchat-3.5": 2,
"chatglm3-6b": 2,
# tier 1
"deluxe-chat-v1.2": 2,
"deluxe-chat-v1.2": 4,
"llama-2-70b-chat": 1.5,
"llama-2-13b-chat": 1.5,
"codellama-34b-instruct": 1.5,
@@ -254,6 +256,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"claude-instant-1": {"gpt-3.5-turbo-1106", "claude-2.1"},
"gemini-pro": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"gemini-pro-dev-api": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-turbo"},
"deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-turbo"},
"pplx-7b-online": {"gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "llama-2-70b-chat"},
@@ -297,19 +300,14 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
}

SAMPLING_BOOST_MODELS = [
# "tulu-2-dpo-70b",
# "yi-34b-chat",
"claude-2.1",
# "claude-1",
# "claude-2.1",
"gpt-4-0613",
# "gpt-3.5-turbo-1106",
# "gpt-4-0314",
"gpt-4-turbo",
# "dolphin-2.2.1-mistral-7b",
# "mixtral-8x7b-instruct-v0.1",
"mistral-medium",
"llama2-70b-steerlm-chat",
"gemini-pro-dev-api",
# "gemini-pro",
# "solar-10.7b-instruct-v1.0",
]

# outage models won't be sampled.
@@ -479,13 +477,22 @@ def bot_response_multi(
            )
        )

    is_gemini = []
    for i in range(num_sides):
        is_gemini.append("gemini" in states[i].model_name)

    chatbots = [None] * num_sides
    iters = 0
    while True:
        stop = True
        iters += 1
        for i in range(num_sides):
            try:
                ret = next(gen[i])
                states[i], chatbots[i] = ret[0], ret[1]
                # yield gemini fewer times as its chunk size is larger
                # otherwise, gemini will stream too fast
                if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
                    ret = next(gen[i])
                    states[i], chatbots[i] = ret[0], ret[1]
                stop = False
            except StopIteration:
                pass
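The new is_gemini gate advances the Gemini generator only on iterations 1, 2, 31, 61, ... while the other side advances on every pass, evening out Gemini's larger chunks. The same pacing rule in isolation (a sketch; the real loop above also re-yields the unchanged chatbots on skipped passes):

    def should_advance(is_gemini: bool, iters: int, every: int = 30, warmup: int = 3) -> bool:
        # Non-Gemini streams advance on every pass; Gemini only on the first
        # couple of passes and then once every `every` iterations.
        return (not is_gemini) or iters % every == 1 or iters < warmup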
13 changes: 11 additions & 2 deletions fastchat/serve/gradio_block_arena_named.py
@@ -242,13 +242,22 @@ def bot_response_multi(
            )
        )

    is_gemini = []
    for i in range(num_sides):
        is_gemini.append("gemini" in states[i].model_name)

    chatbots = [None] * num_sides
    iters = 0
    while True:
        stop = True
        iters += 1
        for i in range(num_sides):
            try:
                ret = next(gen[i])
                states[i], chatbots[i] = ret[0], ret[1]
                # yield gemini fewer times as its chunk size is larger
                # otherwise, gemini will stream too fast
                if not is_gemini[i] or (iters % 30 == 1 or iters < 3):
                    ret = next(gen[i])
                    states[i], chatbots[i] = ret[0], ret[1]
                stop = False
            except StopIteration:
                pass
11 changes: 11 additions & 0 deletions fastchat/serve/gradio_web_server.py
@@ -38,6 +38,7 @@
    palm_api_stream_iter,
    gemini_api_stream_iter,
    mistral_api_stream_iter,
    nvidia_api_stream_iter,
    init_palm_chat,
)
from fastchat.utils import (
@@ -446,6 +447,16 @@ def bot_response(
        stream_iter = mistral_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_api_dict["api_type"] == "nvidia":
        prompt = conv.to_openai_api_messages()
        stream_iter = nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
        )
    else:
        raise NotImplementedError

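For the new branch to be reached, the model must appear in the web server's API-endpoint registry with api_type set to "nvidia". A hypothetical entry showing only the two fields this diff reads (the URL is a placeholder; the bearer token itself comes from the NVIDIA_API_KEY environment variable in api_provider.py):

    api_endpoint_info = {
        "llama2-70b-steerlm-chat": {
            "api_type": "nvidia",  # routes bot_response() to nvidia_api_stream_iter
            "api_base": "https://<nvidia-api-endpoint>",  # placeholder URL
        },
    }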