diff --git a/fastchat/llm_judge/gen_model_answer.py b/fastchat/llm_judge/gen_model_answer.py
index be399750f..44166fb08 100644
--- a/fastchat/llm_judge/gen_model_answer.py
+++ b/fastchat/llm_judge/gen_model_answer.py
@@ -31,6 +31,7 @@ def run_eval(
     num_gpus_total,
     max_gpu_memory,
     dtype,
+    revision,
 ):
     questions = load_questions(question_file, question_begin, question_end)
     # random shuffle the questions to balance the loading
@@ -61,6 +62,7 @@ def run_eval(
                 num_gpus_per_model,
                 max_gpu_memory,
                 dtype=dtype,
+                revision=revision,
             )
         )
 
@@ -79,9 +81,11 @@ def get_model_answers(
     num_gpus_per_model,
     max_gpu_memory,
     dtype,
+    revision,
 ):
     model, tokenizer = load_model(
         model_path,
+        revision=revision,
         device="cuda",
         num_gpus=num_gpus_per_model,
         max_gpu_memory=max_gpu_memory,
@@ -259,6 +263,12 @@ def reorg_answer_file(answer_file):
         help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
         default=None,
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default="main",
+        help="The model revision to load.",
+    )
 
     args = parser.parse_args()
 
@@ -288,6 +298,7 @@ def reorg_answer_file(answer_file):
         num_gpus_total=args.num_gpus_total,
         max_gpu_memory=args.max_gpu_memory,
         dtype=str_to_torch_dtype(args.dtype),
+        revision=args.revision,
     )
 
     reorg_answer_file(answer_file)
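
For context, here is a minimal sketch of what the new `--revision` value controls downstream, assuming `load_model` forwards it to Hugging Face's `from_pretrained` as this patch implies; the model name below is a hypothetical example and not part of the patch:

```python
# Sketch only: --revision accepts a branch name, tag, or commit hash on the Hugging Face Hub.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"  # hypothetical example model
revision = "main"                    # default value of the new --revision flag

# `revision` pins the exact snapshot of weights/config files that get downloaded,
# so answers can be regenerated against a specific model version.
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = AutoModelForCausalLM.from_pretrained(model_path, revision=revision)
```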