diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index ccda8ce2042..6ac3f46bf46 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -1336,6 +1336,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers. // By default, you can assume that they are the same. auto const cyclic_kv_cache_len = static_cast(params.cyclic_attention_window_size); + // The chunked attention size. + auto const chunked_attention_size = static_cast(params.chunked_attention_size); // The number of sink tokens in kv cache to support streamingllm auto const sink_token_len = static_cast(params.sink_token_length); // The current timestep (including paddings). @@ -1361,7 +1363,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS if (sizeof(Tk) != 4) { - auto const max_timesteps = min(timestep, cyclic_kv_cache_len); + auto const max_timesteps = min(timestep, min(cyclic_kv_cache_len, chunked_attention_size)); logits_smem_ += divUp(max_timesteps + 1, 4u) * 16; } Tk* logits_smem = reinterpret_cast(logits_smem_); diff --git a/docs/source/performance/perf-overview.md b/docs/source/performance/perf-overview.md index 9e316617186..c06f4039045 100644 --- a/docs/source/performance/perf-overview.md +++ b/docs/source/performance/perf-overview.md @@ -12,6 +12,8 @@ Tuning batch sizes, parallelism configurations, and other options may lead to im For DeepSeek R1 performance, please check out our [performance guide](../blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md) +For more information on benchmarking with `trtllm-bench` see this NVIDIA [blog post](https://developer.nvidia.com/blog/llm-inference-benchmarking-performance-tuning-with-tensorrt-llm/). + ## Throughput Measurements The below table shows performance data where a local inference client is fed requests at an infinite rate (no delay between messages), @@ -21,50 +23,64 @@ The performance numbers below were collected using the steps described in this d Testing was performed on models with weights quantized using [ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/#) and published by NVIDIA on the [Model Optimizer HuggingFace Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4). -### FP4 Models: -``` +### Hardware +The following GPU variants were used for testing: +- H100 SXM 80GB (DGX H100) +- H200 SXM 141GB (DGX H200) +- GH200 96GB HBM3 (480GB LPDDR5X) +- B200 180GB (DGX B200) +- GB200 192GB (GB200 NVL72) + +Other hardware variants may have different TDP, memory bandwidth, core count, or other features leading to performance differences on these workloads. + +### FP4 Models + +```text nvidia/Llama-3.3-70B-Instruct-FP4 nvidia/Llama-3.1-405B-Instruct-FP4 ``` #### Llama 3.3 70B FP4 -| | GPU | B200 | | | | -|:------------------------|:--------|:----------|:----------|:----------|:----------| -| | TP Size | 1 | 2 | 4 | 8 | -| ISL, OSL | | | | | | -| | | | | | | -| 128, 128 | | 10,994.48 | 17,542.11 | 24,667.31 | 27,272.27 | -| 128, 2048 | | 9,580.46 | 15,432.35 | 23,568.12 | 31,174.31 | -| 128, 4096 | | 6,418.39 | 9,841.53 | 17,808.76 | 25,229.25 | -| 500, 2000 | | 7,343.32 | 11,850.57 | 20,709.67 | 28,038.78 | -| 1000, 1000 | | 6,752.53 | 10,815.88 | 16,413.04 | 20,060.66 | -| 1000, 2000 | | 6,670.07 | 9,830.73 | 15,597.49 | 20,672.37 | -| 1024, 2048 | | 6,636.75 | 9,807.13 | 15,519.23 | 20,617.28 | -| 2048, 128 | | 1,342.17 | 1,989.41 | 3,033.14 | 4,035.64 | -| 5000, 500 | | 1,429.67 | 2,419.67 | 3,686.84 | 5,182.96 | -| 20000, 2000 | | 629.77 | 1,177.01 | 2,120.66 | 3,429.03 | +| | GPU: | B200 | GB200 | +|:-----------------------------|:---|:----------|:--------------| +| | TP Size | 1 | 1 | +| ISL, OSL | | | | +| | | | | +| 128, 128 | | 10,613.84 | 11,100.97 | +| 128, 2048 | | 9,445.51 | 10,276.05 | +| 128, 4096 | | 6,276.85 | 7,351.12 | +| 500, 2000 | | 6,983.27 | 8,194.30 | +| 1000, 1000 | | 6,434.29 | 7,401.80 | +| 1000, 2000 | | 6,725.03 | 6,478.72 | +| 1024, 2048 | | 6,546.61 | 7,922.88 | +| 2048, 128 | | 1,330.35 | 1,418.47 | +| 2048, 2048 | | 4,528.48 | 5,326.77 | +| 5000, 500 | | 1,427.44 | 1,502.44 | +| 20000, 2000 | | 636.36 | 732.43 | #### Llama 3.1 405B FP4 -| | GPU | B200 | | -|:------------------------|:------- |:---------|:----------| -| | TP Size | 4 | 8 | -| ISL, OSL | | | | -| | | | | -| 128, 128 | | 6,163.81 | 9,002.90 | -| 128, 2048 | | 7,081.21 | 10,288.28 | -| 128, 4096 | | 6,028.37 | 8,713.77 | -| 500, 2000 | | 5,858.75 | 9,125.86 | -| 1000, 1000 | | 4,848.00 | 7,582.97 | -| 1000, 2000 | | 5,375.25 | 7,626.28 | -| 1024, 2048 | | 5,345.70 | 7,464.03 | -| 2048, 128 | | 693.55 | 1,086.56 | -| 5000, 500 | | 947.49 | 1,532.45 | -| 20000, 2000 | | 641.11 | 1,097.84 | - -### FP8 Models: -``` +| | GPU: | B200 | GB200 | +|:-----------------------------|:---|:---------|:--------------| +| | TP Size | 4 | 4 | +| ISL, OSL | | | | +| | | | | +| 128, 128 | | 6,218.89 | 6,598.97 | +| 128, 2048 | | 7,178.10 | 7,497.40 | +| 128, 4096 | | 5,890.89 | 5,898.19 | +| 500, 2000 | | 5,844.37 | 6,198.33 | +| 1000, 1000 | | 4,958.53 | 5,243.35 | +| 1000, 2000 | | 4,874.16 | 4,905.51 | +| 1024, 2048 | | 4,833.19 | 4,686.38 | +| 2048, 128 | | 737.95 | 761.58 | +| 2048, 2048 | | 4,024.02 | 4,326.56 | +| 5000, 500 | | 1,032.40 | 1,078.87 | +| 20000, 2000 | | 667.39 | 649.95 | + +### FP8 Models + +```text nvidia/Llama-3.1-8B-Instruct-FP8 nvidia/Llama-3.3-70B-Instruct-FP8 nvidia/Llama-3.1-405B-Instruct-FP8 @@ -73,61 +89,65 @@ nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8 #### Llama 3.1 8B FP8 -| | GPU | H200 141GB HBM3 | H100 80GB HBM3 | -|:-----------------------------|:---|:------------------|:-----------------| -| | TP Size | 1 | 1 | -| ISL, OSL | | | | -| | | | | -| 128, 128 | | 27,970.14 | 27,688.36 | -| 128, 2048 | | 23,326.38 | 21,841.15 | -| 128, 4096 | | 17,508.51 | 13,730.89 | -| 500, 2000 | | 21,390.41 | 17,833.34 | -| 1000, 1000 | | 17,366.89 | 15,270.62 | -| 1000, 2000 | | 16,831.31 | 13,798.08 | -| 1024, 2048 | | 16,737.03 | 13,385.50 | -| 2048, 128 | | 3,488.03 | 3,414.67 | -| 5000, 500 | | 3,813.69 | 3,394.54 | -| 20000, 2000 | | 1,696.66 | 1,345.42 | +| | GPU: | GH200 | H100 | H200 | +|:-----------------------------|:---|:--------------|:-----------------|:------------------| +| | TP Size | 1 | 1 | 1 | +| ISL, OSL | | | | | +| | | | | | +| 128, 128 | | 27,304.25 | 26,401.48 | 27,027.80 | +| 128, 2048 | | 24,045.60 | 21,413.21 | 23,102.25 | +| 128, 4096 | | 15,409.85 | 13,541.54 | 17,396.83 | +| 500, 2000 | | 20,123.88 | 17,571.01 | 19,759.16 | +| 1000, 1000 | | 16,352.99 | 14,991.62 | 17,162.49 | +| 1000, 2000 | | 15,705.82 | 13,505.23 | 16,227.11 | +| 1024, 2048 | | 16,102.52 | 13,165.91 | 16,057.66 | +| 2048, 128 | | 3,573.85 | 3,275.55 | 3,390.69 | +| 2048, 2048 | | 10,767.05 | 9,462.43 | 11,822.14 | +| 5000, 500 | | 3,584.74 | 3,276.47 | 3,758.08 | +| 20000, 2000 | | 1,393.31 | 1,340.69 | 1,705.68 | #### Llama 3.3 70B FP8 -| | GPU | H200 141GB HBM3 | | | | H100 80GB HBM3 | | | | -|:-----------------------------|:---|:------------------|:---------|:----------|:----------|:-----------------|:---------|:----------|:----------| -| | TP Size | 1 | 2 | 4 | 8 | 1 | 2 | 4 | 8 | -| ISL, OSL | | | | | | | | | | -| | | | | | | | | | | -| 128, 128 | | 3,605.47 | 6,427.69 | 10,407.42 | 15,434.37 | 3,128.33 | 6,216.91 | | | -| 128, 2048 | | 4,315.80 | 8,464.03 | 13,508.59 | 20,759.72 | 756.42 | 5,782.57 | 11,464.94 | 17,424.32 | -| 128, 4096 | | 2,701.17 | 5,573.55 | 11,458.56 | 16,668.75 | | 3,868.37 | 8,206.39 | 12,624.61 | -| 500, 2000 | | 3,478.76 | 6,740.06 | 12,200.18 | | | 4,684.06 | 9,903.53 | 14,553.93 | -| 1000, 1000 | | 2,744.32 | 5,119.72 | 8,685.44 | 12,744.51 | 742.14 | 4,247.19 | 7,435.65 | 11,018.81 | -| 1000, 2000 | | 2,896.44 | 5,847.26 | 9,031.21 | 13,141.17 | 533.74 | 3,866.53 | 7,611.12 | 11,139.22 | -| 1024, 2048 | | 2,874.18 | 5,568.61 | 8,946.71 | 13,082.62 | 530.16 | 3,796.68 | 7,575.24 | 11,004.31 | -| 2048, 128 | | 435.90 | 772.67 | 1,264.76 | | | 736.89 | 1,213.33 | 1,839.22 | -| 2048, 2048 | | | | | 10,412.85 | | | | | -| 5000, 500 | | 545.96 | 997.15 | 1,698.22 | 2,655.28 | 204.94 | 862.91 | 1,552.68 | 2,369.84 | -| 20000, 2000 | | 276.66 | 620.33 | 1,161.29 | 1,985.85 | | 416.13 | 903.66 | 1,554.10 | +| | GPU: | H100 | H200 | +|:-----------------------------|:---|:-----------------|:------------------| +| | TP Size | 2 | 2 | +| ISL, OSL | | | | +| | | | | +| 128, 128 | | 6,092.28 | 6,327.98 | +| 128, 2048 | | 5,892.94 | 7,467.36 | +| 128, 4096 | | 3,828.46 | 5,526.42 | +| 500, 2000 | | 4,654.74 | 6,639.15 | +| 1000, 1000 | | 4,181.06 | 4,773.33 | +| 1000, 2000 | | 3,708.93 | 5,790.36 | +| 1024, 2048 | | 3,785.04 | 5,480.44 | +| 2048, 128 | | 723.40 | 747.55 | +| 2048, 2048 | | 2,785.53 | 3,775.80 | +| 5000, 500 | | 865.55 | 978.28 | +| 20000, 2000 | | 411.85 | 609.42 | #### Llama 3.1 405B FP8 - -| | GPU | H200 141GB HBM3 | H100 80GB HBM3 | -|:-----------------------------|:---|:------------------|:-----------------| -| | TP Size | 8 | 8 | -| ISL, OSL | | | | -| | | | | -| 128, 2048 | | 5,567.87 | | -| 128, 4096 | | 5,136.85 | | -| 500, 2000 | | 4,787.61 | 3,673.91 | -| 1000, 1000 | | 3,286.30 | 3,012.22 | -| 1000, 2000 | | 3,636.76 | 3,262.20 | -| 1024, 2048 | | 3,618.66 | 3,109.70 | -| 2048, 128 | | 443.10 | 449.02 | -| 5000, 500 | | 645.46 | | -| 20000, 2000 | | | 372.12 | +| | GPU: | H100 | H200 | +|:-----------------------------|:---|:-----------------|:------------------| +| | TP Size | 8 | 8 | +| Runtime Input/Output Lengths | | | | +| | | | | +| 128, 128 | | | 3,705.18 | +| 128, 2048 | | 4,517.39 | 4,715.13 | +| 128, 4096 | | 2,910.31 | 4,475.91 | +| 500, 2000 | | 3,664.62 | 4,804.10 | +| 1000, 1000 | | 2,955.50 | 3,208.25 | +| 1000, 2000 | | 2,884.69 | 3,630.29 | +| 1024, 2048 | | 3,237.41 | 3,609.50 | +| 2048, 128 | | 433.47 | 441.35 | +| 2048, 2048 | | 2,216.55 | 2,840.86 | +| 5000, 500 | | 579.05 | 645.26 | +| 20000, 2000 | | 363.27 | 509.87 | #### Llama 4 Maverick FP8 -| | GPU | H200 141GB HBM3 | H100 80GB HBM3 | +Note: Performance for Llama 4 on sequence lengths less than 8,192 tokens is affected by an issue introduced in v0.21. To reproduce the Llama 4 performance noted here, please use v0.20 + +| | GPU | H200 | H100 | |:-----------------------------|:---|:------------------|:-----------------| | | TP Size | 8 | 8 | | ISL, OSL | | | | @@ -140,7 +160,6 @@ nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8 | 2048, 128 | | 4,364.06 | 3,832.38 | | 2048, 2048 | | 12,800.89 | | | 5000, 500 | | 5,128.60 | | -| 20000, 2000 | | 1,764.27 | 1,400.79 | ## Reproducing Benchmarked Results @@ -216,7 +235,7 @@ a model name (HuggingFace reference or path to a local model), a [generated data trtllm-bench --model $model_name throughput --dataset $dataset_file --backend pytorch --extra_llm_api_options $llm_options ``` -The data collected for the v0.20 benchmarks was run with the following file: +The data collected for the v0.21 benchmarks was run with the following file: `llm_options.yml` ```yaml @@ -240,7 +259,7 @@ cuda_graph_config: - 8192 ``` -In a majority of cases, we also use a higher KV cache percentage by setting `--kv_cache_free_gpu_mem_fraction 0.95` in the benchmark command. This allows us to obtain better performance than the default setting of `0.90`. We fall back to `0.90` if we hit an out of memory issue. +In many cases, we also use a higher KV cache percentage by setting `--kv_cache_free_gpu_mem_fraction 0.95` in the benchmark command. This allows us to obtain better performance than the default setting of `0.90`. We fall back to `0.90` or lower if out-of-memory errors are encountered. The results will be printed to the terminal upon benchmark completion. For example, diff --git a/docs/source/release-notes.md b/docs/source/release-notes.md index dee84ecfde5..d0cf99c69eb 100644 --- a/docs/source/release-notes.md +++ b/docs/source/release-notes.md @@ -73,6 +73,7 @@ All published functionality in the Release Notes has been fully tested and verif ### Known Issues - accuracy/test_cli_flow::TestGpt2::test_beam_search_large is broken. - Enabling disaggregated serving, MTP, and the overlap scheduler at the same time can lead to accuracy problems. +- In 0.21, full chunked attention support has been added to make sure LLaMA4 model can functionally run with > 8K seq length, while there is a known performance regression(only affect LLaMA4 model) on Hopper due to this functional enhancement. The root cause of the regression has been identified already and the fix will be part of the future release. ## TensorRT-LLM Release 0.20.0 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 8054acea82e..b5c88dec51d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -489,6 +489,51 @@ def test_chunked_prefill(self, attn_backend): task = MMLU(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.skip_less_mpi_world_size(8) + @parametrize_with_ids("cuda_graph", [False, True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), + (8, 1, 8)], + ids=["tp8", "tp8ep4", "tp8ep8"]) + def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size): + with LLM( + f"{llm_models_root()}/llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8", + tensor_parallel_size=tp_size, + # Keep this low to avoid warmup OOM in CI + max_seq_len=8192, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + use_cuda_graph=cuda_graph) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8 + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_hopper + @pytest.mark.skip_less_mpi_world_size(8) + @parametrize_with_ids("cuda_graph", [False, True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 8)], + ids=["tp8ep8"]) + def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size): + with LLM( + f"{llm_models_root()}/llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8", + tensor_parallel_size=tp_size, + # Keep this low to avoid warmup OOM in CI + max_seq_len=8192, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + enable_chunked_prefill=True, + max_num_tokens=256, + use_cuda_graph=cuda_graph) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8 + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @skip_pre_hopper @pytest.mark.skip_less_mpi_world_size(8) @parametrize_with_ids("torch_compile", [True, False]) @@ -587,6 +632,50 @@ def test_fp4(self, cuda_graph, tp_size, pp_size, ep_size): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.skip_less_mpi_world_size(4) + @parametrize_with_ids("cuda_graph", [True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 4)], + ids=["tp4ep4"]) + def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size): + with LLM( + f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8", + tensor_parallel_size=tp_size, + max_seq_len=22000, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + enable_chunked_prefill=True, + max_num_tokens=256, + use_cuda_graph=cuda_graph) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8 + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(8) + @parametrize_with_ids("cuda_graph", [True]) + @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 4)], + ids=["tp4ep4"]) + def test_fp4_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size): + with LLM( + f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + max_seq_len=22000, + enable_chunked_prefill=True, + max_num_tokens=256, + use_cuda_graph=cuda_graph) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8 + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + class TestMistral7B(LlmapiAccuracyTestHarness): MODEL_NAME = "mistralai/Mistral-7B-v0.1" diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 74083eeb1a7..a6f44d431fd 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1923,6 +1923,40 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name, _check_mem_usage(running_log, [mapping[model_name], 0, 0, 0], 8) +@skip_pre_hopper +@pytest.mark.skip_less_device(8) +@pytest.mark.parametrize("cuda_graph", [False, True]) +@pytest.mark.parametrize("model_name,model_path", [ + ("Llama-4-Maverick-17B-128E-Instruct-FP8", + "llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8"), + ("Llama-4-Scout-17B-16E-Instruct-FP8", + "llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8"), + pytest.param('Llama-4-Scout-17B-16E-Instruct-FP4', + 'llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4', + marks=skip_pre_blackwell), +]) +def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k( + llm_root, llm_venv, model_name, model_path, cuda_graph): + print(f"Testing {model_name} on 8 GPUs.") + example_root = Path(os.path.join(llm_root, "examples", "pytorch")) + cmd = [ + str(example_root / "quickstart_advanced.py"), + "--enable_chunked_prefill", + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--tp_size=8", + "--moe_ep_size=8", + "--max_seq_len=22000", + "--kv_cache_fraction=0.1", + ] + if cuda_graph: + cmd.extend([ + "--use_cuda_graph", + "--cuda_graph_padding_enabled", + ]) + llm_venv.run_cmd(cmd) + + # This test is specifically to be run on 2 GPUs on Blackwell RTX 6000 Pro (SM120) architecture # TODO: remove once we have a node with 8 GPUs and reuse test_ptp_quickstart_advanced_8gpus @skip_no_sm120 diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt index 0b838f112f6..c6d642d710e 100644 --- a/tests/integration/test_lists/qa/examples_test_list.txt +++ b/tests/integration/test_lists/qa/examples_test_list.txt @@ -459,6 +459,11 @@ accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep8-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=TRTLLM] +accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8[tp8ep8-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8[tp8ep4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8[tp8-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] +accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp8-cuda_graph=False] @@ -468,6 +473,8 @@ accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp8ep8-cuda_ accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp8ep8-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] +accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4_chunked_prefill[tp4ep4-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2 accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] @@ -541,6 +548,10 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-mode test_e2e.py::test_ptp_quickstart_advanced_8gpus[Mixtral-8x7B-BF16-Mixtral-8x7B-v0.1] test_e2e.py::test_ptp_quickstart_advanced_8gpus[Mixtral-8x7B-NVFP4-nvfp4-quantized/Mixtral-8x7B-Instruct-v0.1] test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1] +test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Maverick-17B-128E-Instruct-FP8-llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-False] +test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Maverick-17B-128E-Instruct-FP8-llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-True] +test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Scout-17B-16E-Instruct-FP8-llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8-True] +test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Scout-17B-16E-Instruct-FP4-llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4-True] test_e2e.py::test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus[DeepSeek-R1-DeepSeek-R1/DeepSeek-R1] test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False]