Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add models & moderation
  • Loading branch information
infwinston committed Jan 19, 2024
commit ac321a2858a4ff628fc2d2cd172074b456bf928e
2 changes: 1 addition & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ class GeminiAdapter(BaseModelAdapter):
"""The model adapter for Gemini"""

def match(self, model_path: str):
return model_path in ["gemini-pro", "gemini-pro-dev-api"]
return "gemini" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()
Expand Down
14 changes: 14 additions & 0 deletions fastchat/model/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def get_model_info(name: str) -> ModelInfo:
"Gemini by Google",
)

# Register the Bard arena entry (backed by Gemini Pro per the model id string)
# so the web server can display its name, link, and description.
register_model_info(
    ["bard-jan-24-gemini-pro"],
    "Bard",
    "https://bard.google.com/",
    "Bard by Google",
)

register_model_info(
["solar-10.7b-instruct-v1.0"],
"SOLAR-10.7B-Instruct",
Expand Down Expand Up @@ -543,3 +550,10 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/meta-math",
"MetaMath is a finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.",
)

# Register the StripedHyena-Nous arena entry. Trailing comma added after the
# description for consistency with the other register_model_info calls in this
# file (and so future additions produce one-line diffs).
register_model_info(
    ["stripedhyena-nous-7b"],
    "StripedHyena-Nous",
    "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B",
    "A chat model developed by Together Research and Nous Research.",
)
80 changes: 78 additions & 2 deletions fastchat/serve/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,12 @@ def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_
yield data


def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens):
def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens, api_key=None):
import google.generativeai as genai # pip install google-generativeai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
if api_key is None:
api_key = os.environ["GEMINI_API_KEY"]
genai.configure(api_key=api_key)

generation_config = {
"temperature": temperature,
Expand Down Expand Up @@ -222,6 +224,80 @@ def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens)
}


def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
    """Call the Bard (PaLM `generateMessage`) REST API and yield response dicts.

    Args:
        model_name: API model id used in the request URL.
        conv: OpenAI-style message list; each item has "role" ("user" or
            "assistant") and "content".
        temperature: sampling temperature forwarded to the API.
        top_p: nucleus-sampling parameter forwarded to the API (sent as "topP").
        api_key: optional key; falls back to the BARD_API_KEY env var.

    Yields:
        Dicts of the form {"text": ..., "error_code": 0 or 1}. On any failure a
        single error dict is yielded and the generator stops.

    Raises:
        ValueError: if a conversation turn has an unsupported role.
    """
    if api_key is None:
        api_key = os.environ["BARD_API_KEY"]

    # Convert the OpenAI-style conversation to Bard's author/content format
    # ("0" = user, "1" = assistant).
    conv_bard = []
    for turn in conv:
        if turn["role"] == "user":
            conv_bard.append({"author": "0", "content": turn["content"]})
        elif turn["role"] == "assistant":
            conv_bard.append({"author": "1", "content": turn["content"]})
        else:
            raise ValueError(f"Unsupported role: {turn['role']}")

    generation_config = {
        "temperature": temperature,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": conv_bard,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    try:
        res = requests.post(
            f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
            json={
                "prompt": {
                    "messages": conv_bard,
                },
                "temperature": temperature,
                "topP": top_p,
            },
            timeout=30,
        )
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}.",
            "error_code": 1,
        }
        # BUG FIX: without this return, `res` is referenced below while
        # undefined, raising NameError after the error was already reported.
        return

    if res.status_code != 200:
        logger.error(f"==== error ==== ({res.status_code}): {res.text}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
            "error_code": 1,
        }
        # BUG FIX: stop here; the body is not valid JSON we can parse below.
        return

    response_json = res.json()
    if "candidates" not in response_json:
        logger.error(f"==== error ==== response blocked: {response_json}")
        # Guard: "filters" may be absent in some blocked responses; avoid a
        # KeyError while reporting the block reason.
        filters = response_json.get("filters") or [{}]
        reason = filters[0].get("reason", "unknown")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
        # BUG FIX: stop instead of falling through to index "candidates".
        return

    response = response_json["candidates"][0]["content"]
    pos = 0
    while pos < len(response):
        # This is a fancy way to simulate token generation latency combined
        # with a Poisson process.
        pos += random.randint(1, 5)
        time.sleep(random.expovariate(200))
        data = {
            "text": response[:pos],
            "error_code": 0,
        }
        yield data


def ai2_api_stream_iter(
model_name,
model_id,
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/call_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, log_dir: str):
self.log_dir = log_dir
self.model_call = {}
self.user_call = {}
self.model_call_limit_global = {"gpt-4-turbo": 1000}
self.model_call_limit_global = {"gpt-4-turbo": 200}
self.model_call_day_limit_per_user = {"gpt-4-turbo": 10}

async def update_stats(self, num_file=1) -> None:
Expand Down
23 changes: 19 additions & 4 deletions fastchat/serve/gradio_block_arena_anony.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"claude-instant-1": 4,
"gemini-pro": 4,
"gemini-pro-dev-api": 4,
"bard-jan-24-gemini-pro": 4,
"pplx-7b-online": 4,
"pplx-70b-online": 4,
"solar-10.7b-instruct-v1.0": 2,
"llama2-70b-steerlm-chat": 2,
"stripedhyena-nous-7b": 4,
"mixtral-8x7b-instruct-v0.1": 4,
"mistral-medium": 4,
"openhermes-2.5-mistral-7b": 2,
Expand Down Expand Up @@ -257,6 +259,18 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"claude-instant-1": {"gpt-3.5-turbo-1106", "claude-2.1"},
"gemini-pro": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"gemini-pro-dev-api": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"bard-jan-24-gemini-pro": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"},
"llama2-70b-steerlm-chat": {
"llama-2-70b-chat",
"tulu-2-dpo-70b",
"yi-34b-chat",
},
"stripedhyena-nous-7b": {
"starling-lm-7b-alpha",
"openhermes-2.5-mistral-7b",
"mistral-7b-instruct",
"llama-2-7b-chat",
},
"deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-turbo"},
"deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-turbo"},
"pplx-7b-online": {"gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "llama-2-70b-chat"},
Expand Down Expand Up @@ -301,12 +315,14 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re

SAMPLING_BOOST_MODELS = [
# "claude-2.1",
"gpt-4-0613",
# "gpt-4-0613",
# "gpt-4-0314",
"gpt-4-turbo",
"mistral-medium",
"llama2-70b-steerlm-chat",
"gemini-pro-dev-api",
# "gemini-pro-dev-api",
"stripedhyena-nous-7b",
"bard-jan-24-gemini-pro",
# "gemini-pro",
]

Expand Down Expand Up @@ -479,8 +495,7 @@ def bot_response_multi(

is_gemini = []
for i in range(num_sides):
is_gemini.append("gemini" in states[i].model_name)

is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])
chatbots = [None] * num_sides
iters = 0
while True:
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/gradio_block_arena_named.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def bot_response_multi(

is_gemini = []
for i in range(num_sides):
is_gemini.append("gemini" in states[i].model_name)
is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"])

chatbots = [None] * num_sides
iters = 0
Expand Down
8 changes: 7 additions & 1 deletion fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
openai_api_stream_iter,
palm_api_stream_iter,
gemini_api_stream_iter,
bard_api_stream_iter,
mistral_api_stream_iter,
nvidia_api_stream_iter,
ai2_api_stream_iter,
Expand Down Expand Up @@ -441,7 +442,12 @@ def bot_response(
)
elif model_api_dict["api_type"] == "gemini":
stream_iter = gemini_api_stream_iter(
model_api_dict["model_name"], conv, temperature, top_p, max_new_tokens
model_api_dict["model_name"], conv, temperature, top_p, max_new_tokens, api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "bard":
prompt = conv.to_openai_api_messages()
stream_iter = bard_api_stream_iter(
model_api_dict["model_name"], prompt, temperature, top_p, api_key=model_api_dict["api_key"],
)
elif model_api_dict["api_type"] == "mistral":
prompt = conv.to_openai_api_messages()
Expand Down
10 changes: 8 additions & 2 deletions fastchat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,17 @@ def oai_moderation(text):
openai.api_type = "open_ai"
openai.api_version = None

threshold_dict = {
"sexual": 0.2,
}
MAX_RETRY = 3
for i in range(MAX_RETRY):
for _ in range(MAX_RETRY):
try:
res = openai.Moderation.create(input=text)
flagged = res["results"][0]["flagged"]
for category, threshold in threshold_dict.items():
if res["results"][0]["category_scores"][category] > threshold:
flagged = True
break
except (openai.error.OpenAIError, KeyError, IndexError) as e:
# flag true to be conservative
Expand All @@ -171,7 +177,7 @@ def oai_moderation(text):


def moderation_filter(text, model_list):
MODEL_KEYWORDS = ["claude"]
MODEL_KEYWORDS = ["claude", "gpt-4", "gpt-3.5", "bard"]

for keyword in MODEL_KEYWORDS:
for model in model_list:
Expand Down