diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py index 098c6d11e..13872bbdd 100644 --- a/fastchat/serve/multi_model_worker.py +++ b/fastchat/serve/multi_model_worker.py @@ -178,6 +178,13 @@ def create_multi_model_worker(): action="append", help="One or more model names. Values must be aligned with `--model-path` values.", ) + parser.add_argument( + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", + ) parser.add_argument("--limit-worker-concurrency", type=int, default=5) parser.add_argument("--stream-interval", type=int, default=2) parser.add_argument("--no-register", action="store_true") @@ -201,9 +208,16 @@ def create_multi_model_worker(): if args.model_names is None: args.model_names = [[x.split("/")[-1]] for x in args.model_path] + if args.conv_template is None: + args.conv_template = [None] * len(args.model_path) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model_path) + # Launch all workers workers = [] - for model_path, model_names in zip(args.model_path, args.model_names): + for conv_template, model_path, model_names in zip( + args.conv_template, args.model_path, args.model_names + ): w = ModelWorker( args.controller_address, args.worker_address, @@ -219,6 +233,7 @@ def create_multi_model_worker(): cpu_offloading=args.cpu_offloading, gptq_config=gptq_config, stream_interval=args.stream_interval, + conv_template=conv_template, ) workers.append(w) for model_name in model_names: