diff --git a/examples/llm-api/quickstart_multimodal.py b/examples/llm-api/quickstart_multimodal.py
index fc18671ee28..dda7ceeadcd 100644
--- a/examples/llm-api/quickstart_multimodal.py
+++ b/examples/llm-api/quickstart_multimodal.py
@@ -108,6 +108,15 @@ def add_multimodal_args(parser):
                         type=str,
                         default="cpu",
                         help="The device to have the input on.")
+    # Add multiturn conversation related parameters
+    parser.add_argument("--multiturn",
+                        action="store_true",
+                        help="Enable multi-turn conversation mode.")
+    parser.add_argument(
+        "--conversation_turns",
+        type=int,
+        default=2,
+        help="Number of conversation turns for automated testing.")
     return parser
 
 
@@ -162,6 +171,80 @@ def main():
         open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
     assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
 
+    # If multiturn mode is enabled
+    if args.multiturn:
+        # Run predefined multiturn conversation examples
+        assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
+        assert args.media is not None, "Please provide media for multiturn conversation."
+        # Determine how many turns to run
+        max_turns = min(args.conversation_turns, len(args.prompt))
+        generated_outputs = []  # Store generated outputs for return
+
+        # Initialize conversation history with the first prompt
+        conversation_history = args.prompt[0] if args.prompt else ""
+
+        for i in range(max_turns):
+            print(f"\n--- Turn {i+1} ---")
+
+            try:
+                # Use multimodal input loader to process input with conversation context
+                # Use accumulated conversation history instead of just the current prompt
+                cur_prompt = conversation_history
+                inputs = default_multimodal_input_loader(
+                    tokenizer=llm.tokenizer,
+                    model_dir=llm._hf_model_dir,
+                    model_type=model_type,
+                    modality=args.modality,
+                    prompts=[cur_prompt],
+                    media=args.media,
+                    image_data_format="pt",
+                    num_frames=8,
+                    device="cpu")
+
+                lora_request = None
+                if args.load_lora:
+                    if model_class is None:
+                        raise ValueError(
+                            "model_class must be provided when load_lora is True"
+                        )
+                    lora_request = model_class.lora_request(
+                        len(inputs), args.modality, llm._hf_model_dir)
+
+                # Generate response
+                outputs = llm.generate(inputs,
+                                       sampling_params,
+                                       lora_request=lora_request)
+                assert outputs and len(
+                    outputs) > 0 and outputs[0].outputs and len(
+                        outputs[0].outputs) > 0
+                response = outputs[0].outputs[0].text.strip()
+
+                # Store generated output
+                generated_outputs.append({
+                    "turn": i + 1,
+                    "user_input": cur_prompt,
+                    "assistant_response": response,
+                    "media": args.media
+                })
+
+                conversation_history = conversation_history + "\n" + response
+                if i + 1 < len(args.prompt):
+                    conversation_history = conversation_history + "\n" + args.prompt[
+                        i + 1]
+
+            except Exception as e:
+                print(f"Error in turn {i+1}: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+        for i, output in enumerate(generated_outputs):
+            print(
+                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
+            )
+        return
+
+    # Original single-turn processing logic
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
         args.prompt = example_medias_and_prompts[args.modality]["prompt"]
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 5e72cc225e1..ba6ad70265b 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -19,6 +19,13 @@ meta-llama/Llama-3.3-70B-Instruct:
     accuracy: 84.08
 meta-llama/Llama-4-Maverick-17B-128E-Instruct:
   - accuracy: 92.20
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 92.20
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    spec_dec_algo: Eagle
+    accuracy: 92.20
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 89.70
   - quant_algo: NVFP4
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 70cfb64bfbe..01516a0b703 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -73,6 +73,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
     kv_cache_quant_algo: FP8
     spec_dec_algo: Eagle
     accuracy: 86.40
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 86.40
 meta-llama/Llama-4-Scout-17B-16E-Instruct:
   - accuracy: 80.00
   - quant_algo: NVFP4
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index 4ee258280c7..8fd7508b075 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -246,7 +246,7 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
     total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
     total_gen_gpus = gen_tp * gen_pp * gen_instances
     if total_ctx_gpus + total_gen_gpus > get_device_count():
-        pytest.fail(
+        pytest.skip(
             f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
         )
 
@@ -378,6 +378,7 @@ def test_ngram(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_hopper
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("eagle3_one_model", [True, False])
     def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@@ -461,6 +462,7 @@ def test_multi_instance(self, testset):
 
 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)
+@pytest.mark.skip_less_device(4)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@@ -540,6 +542,7 @@ def test_nixl_backend(self):
     @parametrize_with_ids("overlap_scheduler", [True, False])
     @parametrize_with_ids("mtp_nextn",
                           [0, pytest.param(2, marks=skip_pre_hopper)])
+    @pytest.mark.skip_less_device(4)
     def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         ctx_server_config = {"disable_overlap_scheduler": True}
         gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@@ -671,6 +674,7 @@ def test_nixl_backend(self):
         task.evaluate(llm)
 
     @pytest.mark.parametrize("overlap_scheduler", [False, True])
+    @skip_pre_hopper
     def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config = {
             "disable_overlap_scheduler": True,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index ecb8a1980bf..c667bbaa10a 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -695,6 +695,7 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
     MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"
 
+    @pytest.mark.skip_less_device_memory(80000)
    def test_auto_dtype(self):
         with LLM(self.MODEL_PATH) as llm:
             task = CnnDailymail(self.MODEL_NAME)
@@ -1033,7 +1034,7 @@ def test_cute_dsl_fp8_block_scales(
                 max_num_streams=3) if torch_compile else None)
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            use_cuda_graph=cuda_graph,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
             moe_config=MoeConfig(backend="CUTEDSL"),
         )
@@ -1191,7 +1192,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
                 max_num_streams=3) if torch_compile else None)
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            use_cuda_graph=cuda_graph,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
             moe_config=MoeConfig(backend="CUTEDSL"),
         )
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index d247063a125..2921b4c2d04 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -2051,7 +2051,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
 def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
         llm_root, llm_venv, model_name, model_path, cuda_graph):
     print(f"Testing {model_name} on 8 GPUs.")
-    example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
     cmd = [
         str(example_root / "quickstart_advanced.py"),
         "--enable_chunked_prefill",
@@ -2076,10 +2076,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
 @pytest.mark.skip_less_device_memory(80000)
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("model_name,model_path", [
-    ("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
     ('Nemotron-Super-49B-v1-BF16',
      'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
     ("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
+    pytest.param('Llama3.1-70B-BF16',
+                 'llama-3.1-model/Meta-Llama-3.1-70B',
+                 marks=pytest.mark.skip_less_device_memory(95000)),
 ])
 def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
                                              model_path):
@@ -2521,6 +2523,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
     print("All answers are correct!")
 
 
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("model_name,model_path", [
+    ("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
+    ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
+    ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
+])
+def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
+                                             model_path):
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    test_data_root = Path(
+        os.path.join(llm_models_root(), "multimodals", "test_data"))
+
+    print(f"Accuracy test {model_name} image mode with example inputs.")
+
+    # Define accuracy inputs for image modality
+    accuracy_inputs = {
+        "image": {
+            "prompt": [
+                "Describe what you see in this image.",
+                "How would you describe the atmosphere of this scene?",
+            ],
+            "media": [
+                str(test_data_root / "inpaint.png"),
+            ],
+        }
+    }
+
+    # Define expected keywords for each model
+    expected_keywords = {
+        "gemma-3-27b-it": {
+            "image": [
+                ["half", "dome", "yosemite", "landmark", "rounded"],
+                ["atmosphere", "peaceful", "majestic", "calm", "quiet"],
+            ],
+        },
+        "mistral-small-3.1-24b-instruct": {
+            "image": [
+                ["depicts", "landscape", "rock", "sky", "high", "altitude"],
+                ["atmosphere", "serene", "majestic", "sense", "tranquility"],
+            ],
+        },
+        "Phi-4-multimodal-instruct": {
+            "image": [
+                ["depicts", "landscape", "mountain", "half", "dome"],
+                ["atmosphere", "serene", "sense", "tranquility", "peace."],
+            ],
+        },
+    }
+    # Build command for image modality
+    cmd = [
+        str(example_root / "quickstart_multimodal.py"),
+        "--model_dir",
+        f"{llm_models_root()}/{model_path}",
+        "--modality",
+        "image",
+        "--multiturn",
+        "--prompt",
+        *accuracy_inputs["image"]["prompt"],
+        "--media",
+        *accuracy_inputs["image"]["media"],
+    ]
+
+    # Add model-specific configurations
+    if model_name == "gemma-3-27b-it":
+        # Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
+        # Custom mask involves bidirectional masking of image tokens in context phase. To get this
+        # correct, chunked prefill and kv cache reuse need to be turned off.
+        cmd.append("--image_format=pil")
+        cmd.append("--attention_backend=FLASHINFER")
+        cmd.append("--disable_kv_cache_reuse")
+    elif model_name == "Phi-4-multimodal-instruct":
+        # Set max_seq_len to 4096 to use short rope factor.
+        cmd.append("--max_seq_len=4096")
+        cmd.append("--load_lora")
+        cmd.append("--auto_model_name")
+        cmd.append("Phi4MMForCausalLM")
+
+    output = llm_venv.run_cmd(cmd, caller=check_output)
+    print("output:", output)
+    # Set match ratio based on model
+    match_ratio = 4.0 / 5
+    if model_name == "Phi-4-multimodal-instruct":
+        match_ratio = 0.6
+
+    # Check output accuracy
+    for prompt_output, prompt_keywords in zip(
+            parse_output(output), expected_keywords[model_name]["image"]):
+        matches = [
+            keyword in prompt_output.lower() for keyword in prompt_keywords
+        ]
+        obs_match_ratio = 1. * sum(matches) / len(matches)
+        print("prompt_output:", prompt_output)
+        print("prompt_keywords:", prompt_keywords)
+        print("matches:", matches)
+        print("obs_match_ratio:", obs_match_ratio)
+        assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
+
+    print("All answers are correct!")
+
+
 @pytest.mark.parametrize("model_name,model_path", [
     ("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
 ])
diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt
index 5c0a9585ffc..42e04be7e78 100644
--- a/tests/integration/test_lists/qa/llm_function_full.txt
+++ b/tests/integration/test_lists/qa/llm_function_full.txt
@@ -602,6 +602,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
+test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
 test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
 test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]