diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 29829f7f644..9a98d14e7e8 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -137,6 +137,9 @@ def __init__(
         self.pytorch_backend_config.attention_dp_time_out_iters = 50
         self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.pytorch_backend_config.batch_wait_timeout_ms = 0
+        self.pytorch_backend_config.batch_wait_timeout_iters = 0
+        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
+        self.pytorch_backend_config.max_num_tokens = 8192
         self.iter_counter = 0
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 3be1e5558fc..9dc8f51ac2c 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -50,7 +50,11 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10
 
+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    batch_wait_timeout_iters: int = 0
+    batch_wait_max_tokens_ratio: float = 0
 
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 4b3315560f8..dbe1126b7f1 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -184,6 +184,7 @@ def __init__(
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -192,6 +193,10 @@ def __init__(
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -236,6 +241,7 @@ def __init__(
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0
 
         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1334,6 +1340,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests
 
+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1344,6 +1371,14 @@ def _schedule(self):
             scheduler_output.context_requests,
             scheduler_output.generation_requests)
 
+        # Only delay context requests while generation requests are in flight; otherwise waiting could stall indefinitely.
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                    scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 62db888a447..81f621ffe52 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2255,6 +2255,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")
 
+    batch_wait_timeout_iters: int = Field(
+        default=0,
+        description=
+        "Maximum number of scheduler iterations to wait for additional incoming requests before scheduling a batch, which can improve GPU utilization by forming larger batches. If greater than 0, the scheduler delays batch processing for up to the specified number of iterations while more requests accumulate. If 0, no iteration-based batching delay is applied.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold for batch scheduling, expressed as a fraction of max_num_tokens. If greater than 0, the scheduler accumulates requests until the total scheduled token count reaches batch_wait_max_tokens_ratio * max_num_tokens, which improves GPU utilization by ensuring adequately sized batches. If 0, no token-based batching delay is applied.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")
@@ -2528,6 +2540,22 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self
 
+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError(
+                f"batch_wait_timeout_iters must be >= 0, got {self.batch_wait_timeout_iters}"
+            )
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                f"batch_wait_max_tokens_ratio must be in range [0, 1], got {self.batch_wait_max_tokens_ratio}"
+            )
+        return self
+
     def get_executor_config(
         self,
         _hf_model_dir: Optional[Path] = None,
@@ -2603,7 +2631,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )
 
 
 def update_llm_args_with_extra_dict(
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 2667657a168..68b760f52ba 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1546,6 +1546,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
+                                                                 (10, 0.75),
+                                                                 (10, 0),
+                                                                 (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
+                                 overlap_scheduler, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])
diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt
index 18d4faa2e89..d1421f90d98 100644
--- a/tests/integration/test_lists/qa/llm_function_full.txt
+++ b/tests/integration/test_lists/qa/llm_function_full.txt
@@ -512,6 +512,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 7154238f6a8..9618c8972d4 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -131,6 +131,14 @@ methods:
       annotation: float
       default: 0
       status: prototype
+    batch_wait_timeout_iters:
+      annotation: int
+      default: 0
+      status: prototype
+    batch_wait_max_tokens_ratio:
+      annotation: float
+      default: 0
+      status: prototype
     print_iter_log:
       annotation: bool
       default: False
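
Usage sketch (not part of the patch): the two new prototype knobs are ordinary TorchLlmArgs fields, so they are expected to be passed directly to the PyTorch-backend LLM constructor, mirroring how the integration test above forwards them via pytorch_config. The model path below is a placeholder.

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig

    llm = LLM(
        model="<model-checkpoint-path>",   # placeholder path
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.75),
        batch_wait_timeout_iters=10,       # wait at most 10 scheduler iterations for more requests
        batch_wait_max_tokens_ratio=0.75,  # or stop waiting once 75% of max_num_tokens is accumulated
    )
    # Waiting stops as soon as either condition is met (see _waiting_requests above).
    # With both knobs left at their defaults (0), no batch waiting is applied and
    # scheduling behavior is unchanged.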