Add mistral api & refactor (#2875)
infwinston authored Jan 4, 2024
commit 8756b52f97ff4785c63bb013e9c32d94bc8f96a6
9 changes: 9 additions & 0 deletions fastchat/conversation.py
@@ -789,6 +789,15 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="gemini",
+        roles=("user", "model"),
+        sep_style=None,
+        sep=None,
+    )
+)
+
 # BiLLa default template
 register_conv_template(
     Conversation(
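The new "gemini" template carries no separator style because Gemini consumes structured chat history rather than a flat prompt string; the template only pins down the role names. A minimal sketch of how it behaves, assuming fastchat is installed (the history-building line mirrors what the API provider later in this commit does):

```python
from fastchat.conversation import get_conv_template

# get_conv_template returns a copy, so local appends do not
# mutate the registered template.
conv = get_conv_template("gemini")
conv.append_message(conv.roles[0], "Hello!")  # role "user"
conv.append_message(conv.roles[1], None)      # role "model", reply placeholder

# With sep_style=None there is no prompt string to render; the
# message list itself becomes the structured API history.
history = [{"role": r, "parts": m} for r, m in conv.messages if m is not None]
print(history)  # [{'role': 'user', 'parts': 'Hello!'}]
```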
4 changes: 2 additions & 2 deletions fastchat/model/model_adapter.py
@@ -1126,13 +1126,13 @@ class GeminiAdapter(BaseModelAdapter):
     """The model adapter for Gemini"""
 
     def match(self, model_path: str):
-        return model_path in ["gemini-pro"]
+        return model_path in ["gemini-pro", "gemini-pro-dev-api"]
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         raise NotImplementedError()
 
     def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("bard")
+        return get_conv_template("gemini")
 
 
 class BiLLaAdapter(BaseModelAdapter):
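Adapters are matched by exact model name, so the new "gemini-pro-dev-api" id should resolve to the same adapter and default template as "gemini-pro". A quick check, assuming fastchat is importable:

```python
from fastchat.model.model_adapter import get_model_adapter

for name in ["gemini-pro", "gemini-pro-dev-api"]:
    adapter = get_model_adapter(name)
    conv = adapter.get_default_conv_template(name)
    # Both names are expected to hit GeminiAdapter and the "gemini" template.
    print(name, type(adapter).__name__, conv.name)
```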
4 changes: 2 additions & 2 deletions fastchat/model/model_registry.py
@@ -29,14 +29,14 @@ def get_model_info(name: str) -> ModelInfo:
 
 
 register_model_info(
-    ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"],
+    ["mixtral-8x7b-instruct-v0.1", "mistral-medium", "mistral-7b-instruct"],
     "Mixtral of experts",
     "https://mistral.ai/news/mixtral-of-experts/",
     "A Mixture-of-Experts model by Mistral AI",
 )
 
 register_model_info(
-    ["gemini-pro"],
+    ["gemini-pro", "gemini-pro-dev-api"],
     "Gemini",
     "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/",
     "Gemini by Google",
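A small sketch of reading the registry back, assuming fastchat is importable; ModelInfo carries simple_name, link, and description:

```python
from fastchat.model.model_registry import get_model_info

for name in ["mistral-medium", "gemini-pro-dev-api"]:
    info = get_model_info(name)
    print(name, "->", info.simple_name, info.link)
```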
80 changes: 80 additions & 0 deletions fastchat/serve/api_provider.py
@@ -167,6 +167,44 @@ def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_
         yield data
 
 
+def gemini_api_stream_iter(model_name, conv, temperature, top_p, max_new_tokens):
+    import google.generativeai as genai  # pip install google-generativeai
+
+    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+    generation_config = {
+        "temperature": temperature,
+        "max_output_tokens": max_new_tokens,
+        "top_p": top_p,
+    }
+
+    safety_settings = [
+        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+    ]
+    model = genai.GenerativeModel(
+        model_name=model_name,
+        generation_config=generation_config,
+        safety_settings=safety_settings,
+    )
+    history = []
+    for role, message in conv.messages[:-2]:
+        history.append({"role": role, "parts": message})
+    convo = model.start_chat(history=history)
+    response = convo.send_message(conv.messages[-2][1], stream=True)
+
+    text = ""
+    for chunk in response:
+        text += chunk.text
+        data = {
+            "text": text,
+            "error_code": 0,
+        }
+        yield data
+
+
 def ai2_api_stream_iter(
     model_name,
     messages,
@@ -239,3 +277,45 @@ def ai2_api_stream_iter(
"error_code": 0,
}
yield data


def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

api_key = os.environ["MISTRAL_API_KEY"]

client = MistralClient(api_key=api_key)

# Make requests
gen_params = {
"model": model_name,
"prompt": messages,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
}
logger.info(f"==== request ====\n{gen_params}")

new_messages = [
ChatMessage(role=message["role"], content=message["content"])
for message in messages
]

res = client.chat_stream(
model=model_name,
temperature=temperature,
messages=new_messages,
max_tokens=max_new_tokens,
top_p=top_p,
)

text = ""
for chunk in res:
if chunk.choices[0].delta.content is not None:
text += chunk.choices[0].delta.content
data = {
"text": text,
"error_code": 0,
}
yield data
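Both new iterators follow the same cumulative-text protocol as the existing providers, but their inputs differ: the Gemini helper takes a Conversation whose last two messages are the pending user turn and an empty reply placeholder, while the Mistral helper takes plain role/content dicts. A minimal driver sketch, assuming GEMINI_API_KEY and MISTRAL_API_KEY are set and both client packages are installed:

```python
from fastchat.conversation import get_conv_template
from fastchat.serve.api_provider import (
    gemini_api_stream_iter,
    mistral_api_stream_iter,
)

# Gemini: Conversation ending in a user turn plus an empty placeholder.
conv = get_conv_template("gemini")
conv.append_message(conv.roles[0], "One fun fact about the Moon, please.")
conv.append_message(conv.roles[1], None)
for data in gemini_api_stream_iter("gemini-pro", conv, 0.7, 1.0, 256):
    pass  # each yielded dict carries the full text so far
print(data["text"])

# Mistral: plain role/content dicts, no Conversation needed.
messages = [{"role": "user", "content": "One fun fact about Mars, please."}]
for data in mistral_api_stream_iter("mistral-medium", messages, 0.7, 1.0, 256):
    pass
print(data["text"])
```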
18 changes: 13 additions & 5 deletions fastchat/serve/gradio_block_arena_anony.py
@@ -162,6 +162,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
     # tier 0
     "gpt-4": 4,
     "gpt-4-0314": 4,
+    "gpt-4-0613": 4,
     "gpt-4-turbo": 4,
     "gpt-3.5-turbo-0613": 2,
     "gpt-3.5-turbo-1106": 2,
@@ -174,6 +175,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
     "pplx-70b-online": 4,
     "solar-10.7b-instruct-v1.0": 2,
     "mixtral-8x7b-instruct-v0.1": 4,
+    "mistral-medium": 8,
     "openhermes-2.5-mistral-7b": 2,
     "dolphin-2.2.1-mistral-7b": 2,
     "wizardlm-70b": 2,
@@ -235,6 +237,12 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
         "gpt-3.5-turbo-0613",
         "llama-2-70b-chat",
     },
+    "mistral-medium": {
+        "gpt-3.5-turbo-1106",
+        "gpt-3.5-turbo-0613",
+        "gpt-4-turbo",
+        "mixtral-8x7b-instruct-v0.1",
+    },
     "mixtral-8x7b-instruct-v0.1": {
         "gpt-3.5-turbo-1106",
         "gpt-3.5-turbo-0613",
@@ -292,15 +300,16 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
     # "tulu-2-dpo-70b",
     # "yi-34b-chat",
     "claude-2.1",
-    "claude-1",
+    # "claude-1",
     "gpt-4-0613",
     # "gpt-3.5-turbo-1106",
     # "gpt-4-0314",
     "gpt-4-turbo",
     # "dolphin-2.2.1-mistral-7b",
-    "mixtral-8x7b-instruct-v0.1",
-    "gemini-pro",
-    "solar-10.7b-instruct-v1.0",
+    # "mixtral-8x7b-instruct-v0.1",
+    "mistral-medium",
+    # "gemini-pro",
+    # "solar-10.7b-instruct-v1.0",
 ]
 
 # outage models won't be sampled.
@@ -544,7 +553,6 @@ def build_side_by_side_ui_anony(models):
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
-           container=False,
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)
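For context, these weights drive anonymous battle sampling: mistral-medium enters at weight 8, above the 4 used for the gpt-4 tier, and BATTLE_TARGETS biases which opponents it meets. A rough, self-contained sketch of such a scheme, not the arena's exact code:

```python
import random

SAMPLING_WEIGHTS = {"gpt-4-turbo": 4, "mistral-medium": 8, "claude-2.1": 4}
BATTLE_TARGETS = {"mistral-medium": {"gpt-4-turbo"}}

def sample_battle_pair(models, target_prob=0.5):
    # First model: weighted draw, heavier weights sampled more often.
    a = random.choices(models, weights=[SAMPLING_WEIGHTS.get(m, 1) for m in models])[0]
    # Opponent: sometimes forced from the target set, otherwise weighted.
    targets = [m for m in BATTLE_TARGETS.get(a, ()) if m in models]
    if targets and random.random() < target_prob:
        return a, random.choice(targets)
    rest = [m for m in models if m != a]
    b = random.choices(rest, weights=[SAMPLING_WEIGHTS.get(m, 1) for m in rest])[0]
    return a, b

print(sample_battle_pair(["gpt-4-turbo", "mistral-medium", "claude-2.1"]))
```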
1 change: 0 additions & 1 deletion fastchat/serve/gradio_block_arena_named.py
@@ -326,7 +326,6 @@ def build_side_by_side_ui_named(models):
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
-           container=False,
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)