Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move all model-related arguments into --model-info-file
  • Loading branch information
hnyls2002 committed Sep 29, 2023
commit cc6b7ac53ecabc2fb350043319b9e07bce700a1c
63 changes: 27 additions & 36 deletions fastchat/serve/huggingface_api_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
"api_base": "https://api-inference.huggingface.co/models",
"token": "hf_xxx",
"context_length": 2048
"model_names": "falcon-180b-chat",
"conv_template": null,
}
}

Only "model_path", "api_base", and "token" are necessary, others are optional.
"""
import argparse
import asyncio
Expand Down Expand Up @@ -270,34 +274,14 @@ def create_huggingface_api_worker():
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
# all model-related parameters are listed in --model-info-file
parser.add_argument(
"--model-info-file",
type=str,
required=True,
help="Huggingface API model's info file path",
)

# support multi huggingface api models here
parser.add_argument(
"--model",
type=str,
default=[],
action="append",
help="The models' names to be called.",
)
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
action="append",
help="One or more model names. Values must be aligned with `--model` values.",
)
parser.add_argument(
"--conv-template",
type=str,
default=None,
action="append",
help="Conversation prompt template. Values must be aligned with `--model` values. If only one value is provided, it will be repeated for all models.",
)
parser.add_argument(
"--limit-worker-concurrency",
type=int,
Expand All @@ -313,13 +297,6 @@ def create_huggingface_api_worker():
)
args = parser.parse_args()

if args.model_names is None:
args.model_names = [[x.split("/")[-1]] for x in args.model]
if args.conv_template is None:
args.conv_template = [None] * len(args.model)
elif len(args.conv_template) == 1: # Repeat the same template
args.conv_template = args.conv_template * len(args.model)

with open(args.model_info_file, "r", encoding="UTF-8") as f:
model_info = json.load(f)

Expand All @@ -329,16 +306,30 @@ def create_huggingface_api_worker():
api_base_list = []
token_list = []
context_length_list = []
model_names_list = []
conv_template_list = []

for m in args.model:
if m not in model_info:
raise ValueError(
f"Model {args.model} not supported. Please add it to {args.model_info_file}."
)
for m in model_info:
model_path_list.append(model_info[m]["model_path"])
api_base_list.append(model_info[m]["api_base"])
token_list.append(model_info[m]["token"])
context_length_list.append(model_info[m]["context_length"])

context_length = model_info[m].get("context_length", 1024)
model_names = model_info[m].get("model_names", [m.split("/")[-1]])
if isinstance(model_names, str):
model_names = [model_names]
conv_template = model_info[m].get("conv_template", None)

context_length_list.append(context_length)
model_names_list.append(model_names)
conv_template_list.append(conv_template)

logger.info(f"Model paths: {model_path_list}")
logger.info(f"API bases: {api_base_list}")
logger.info(f"Tokens: {token_list}")
logger.info(f"Context lengths: {context_length_list}")
logger.info(f"Model names: {model_names_list}")
logger.info(f"Conv templates: {conv_template_list}")

for (
model_names,
Expand All @@ -348,8 +339,8 @@ def create_huggingface_api_worker():
token,
context_length,
) in zip(
args.model_names,
args.conv_template,
model_names_list,
conv_template_list,
model_path_list,
api_base_list,
token_list,
Expand Down