From 5cab371d29431e476a9229a7404f2fb2a9c79bbd Mon Sep 17 00:00:00 2001
From: yunruis <205571022+yunruis@users.noreply.github.com>
Date: Fri, 5 Sep 2025 09:35:19 +0800
Subject: [PATCH] [None][opt] Add batch waiting when scheduling (#7287)

Signed-off-by: yunruis <205571022+yunruis@users.noreply.github.com>
Co-authored-by: Tao Li @ NVIDIA
---
 .../_torch/auto_deploy/shim/ad_executor.py    |  3 ++
 tensorrt_llm/_torch/pyexecutor/config.py      |  7 +++
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 35 +++++++++++++++
 tensorrt_llm/llmapi/llm_args.py               | 42 +++++++++++++++++-
 .../defs/accuracy/test_llm_api_pytorch.py     | 43 +++++++++++++++++++
 .../test_lists/qa/llm_function_core.txt       |  1 +
 .../api_stability/references/llm.yaml         |  8 ++++
 7 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 998c8a178f2..bf68dd28c41 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -142,6 +142,9 @@ def __init__(
         self.pytorch_backend_config.attention_dp_time_out_iters = 50
         self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.pytorch_backend_config.batch_wait_timeout_ms = 0
+        self.pytorch_backend_config.batch_wait_timeout_iters = 0
+        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
+        self.pytorch_backend_config.max_num_tokens = seq_info.max_num_tokens
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 7f46c521b6f..952dace406e 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -49,7 +49,14 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10

+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    # Max number of iterations to hold back context requests before scheduling them anyway (0 disables).
+    batch_wait_timeout_iters: int = 0
+    # Schedule held-back context requests once the scheduled token count reaches this ratio of max_num_tokens.
+    # Value range: [0, 1] (0 disables).
+    batch_wait_max_tokens_ratio: float = 0.0

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
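
Both new fields default to a disabled state and ride on PyTorchConfig like any other backend knob. A minimal sketch of setting them directly, assuming the remaining dataclass fields keep their defaults (the values below are illustrative, not recommendations):

    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

    # Hold context requests for at most 10 iterations, or release them as soon as
    # the scheduled batch reaches 75% of the 8192-token budget (illustrative values).
    cfg = PyTorchConfig(max_num_tokens=8192,
                        batch_wait_timeout_iters=10,
                        batch_wait_max_tokens_ratio=0.75)
    assert cfg.batch_wait_timeout_iters == 10
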
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 40dd9f9b071..3c44ae0e541 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -190,6 +190,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -198,6 +199,10 @@ def __init__(self,
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -244,6 +249,7 @@ def __init__(self,
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1397,6 +1403,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
                 balanced_context_requests = context_requests
         return balanced_context_requests

+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1407,6 +1434,14 @@ def _schedule(self):
                 scheduler_output.context_requests,
                 scheduler_output.generation_requests)

+        # Only wait when both context and generation requests were scheduled; otherwise schedule immediately to avoid waiting forever.
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                    scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
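
At its core, _waiting_requests applies a two-condition release rule: held-back context requests are scheduled once either the iteration budget is used up or the scheduled batch already carries enough tokens. A standalone sketch of that rule, with hypothetical helper and parameter names that are not part of the patch:

    def should_release_context(wait_iters: int, scheduled_tokens: int,
                               timeout_iters: int, max_tokens_ratio: float,
                               max_num_tokens: int) -> bool:
        # Release held-back context requests once we have waited long enough
        # or the batch already carries enough tokens.
        return (wait_iters >= timeout_iters
                or scheduled_tokens >= max_tokens_ratio * max_num_tokens)

    # With timeout_iters=10 and a 0.75 ratio of an 8192-token budget, a batch of
    # 6144+ tokens is released immediately; a smaller one waits at most 10 iterations.
    assert should_release_context(0, 6144, 10, 0.75, 8192)
    assert not should_release_context(3, 1024, 10, 0.75, 8192)
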
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index ecf0ffdf362..b527b8a450d 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2208,6 +2208,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")

+    batch_wait_timeout_iters: int = Field(
+        default=0,
+        description=
+        "Maximum number of iterations the scheduler will wait to accumulate newly arriving requests, improving GPU utilization. If greater than 0, the scheduler delays batch processing for up to this many iterations to gather more requests. If 0, iteration-based batching delays are disabled.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold ratio for batch scheduling. If greater than 0, the scheduler accumulates requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens, which improves GPU utilization by ensuring adequately sized batches. If 0, token-based batching delays are disabled.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")

@@ -2481,6 +2493,31 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self

+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError(
+                f"batch_wait_timeout_iters must be >= 0, got {self.batch_wait_timeout_iters}"
+            )
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                f"batch_wait_max_tokens_ratio must be in range [0, 1], got {self.batch_wait_max_tokens_ratio}"
+            )
+        return self
+
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
+        executor_config = super().get_executor_config(_hf_model_dir, tokenizer)
+        executor_config.mm_encoder_only = self.mm_encoder_only
+        return executor_config
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2547,7 +2584,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )


 def update_llm_args_with_extra_dict(
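
Because the two knobs are ordinary TorchLlmArgs fields, they can be passed straight to the PyTorch-backend LLM constructor, as the new accuracy test below does. A rough usage sketch; the model path and values are placeholders:

    from tensorrt_llm import LLM

    llm = LLM(
        "/path/to/model",                  # placeholder checkpoint directory
        batch_wait_timeout_iters=10,       # wait at most 10 scheduler iterations
        batch_wait_max_tokens_ratio=0.75,  # or until 75% of max_num_tokens is scheduled
    )
    # Values outside the validated ranges, e.g. a negative iteration count or a
    # ratio above 1, are rejected by the validators added above.
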
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index d8639e19149..a20136d1ea4 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1606,6 +1606,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
+                                                                 (10, 0.75),
+                                                                 (10, 0),
+                                                                 (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
+                                 overlap_scheduler, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index c1ab5b3a1ce..b556560e682 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -465,6 +465,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 51518bbccdb..eaf34c56d37 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -131,6 +131,14 @@ methods:
         annotation: float
         default: 0
         status: prototype
+      batch_wait_timeout_iters:
+        annotation: int
+        default: 0
+        status: prototype
+      batch_wait_max_tokens_ratio:
+        annotation: float
+        default: 0
+        status: prototype
       print_iter_log:
         annotation: bool
         default: False
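
Taken together, the two limits interact through the executor's batch_wait_iters_count: the counter grows while context requests are withheld and resets whenever they are released. A toy simulation of that behaviour (illustrative only, not the executor's real control flow):

    def simulate(per_iter_tokens, timeout_iters=10, ratio=0.75, max_num_tokens=8192):
        wait_iters = 0
        for step, tokens in enumerate(per_iter_tokens):
            release = (wait_iters >= timeout_iters
                       or tokens >= ratio * max_num_tokens)
            wait_iters = 0 if release else wait_iters + 1
            print(f"iter {step}: tokens={tokens} release_context={release}")

    # Small batches are released only once the 10-iteration budget is exhausted;
    # a batch carrying 7000 tokens (>= 0.75 * 8192 = 6144) goes through immediately.
    simulate([1024] * 12 + [7000])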