3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -137,6 +137,9 @@ def __init__(
self.pytorch_backend_config.attention_dp_time_out_iters = 50
self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
self.pytorch_backend_config.batch_wait_timeout_ms = 0
self.pytorch_backend_config.batch_wait_timeout_iters = 0
self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
self.pytorch_backend_config.max_num_tokens = 8192
self.iter_counter = 0

# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -50,7 +50,11 @@ class PyTorchConfig:
attention_dp_time_out_iters: int = 50
attention_dp_batching_wait_iters: int = 10

max_num_tokens: int = 8192

batch_wait_timeout_ms: float = 0
batch_wait_timeout_iters: int = 0
batch_wait_max_tokens_ratio: float = 0

attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'
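For orientation, here is a minimal, self-contained sketch of how the new defaults combine into a token threshold. The dataclass is purely illustrative and only mirrors the fields added above; it is not the real PyTorchConfig.

from dataclasses import dataclass


@dataclass
class BatchWaitSettings:
    # Illustrative mirror of the new PyTorchConfig fields above.
    max_num_tokens: int = 8192
    batch_wait_timeout_iters: int = 0
    batch_wait_max_tokens_ratio: float = 0.0

    def token_threshold(self) -> float:
        # Context requests are held back until the scheduled token count
        # reaches this value, or until batch_wait_timeout_iters iterations pass.
        return self.batch_wait_max_tokens_ratio * self.max_num_tokens


settings = BatchWaitSettings(batch_wait_timeout_iters=10,
                             batch_wait_max_tokens_ratio=0.75)
print(settings.token_threshold())  # 0.75 * 8192 = 6144.0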
35 changes: 35 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -184,6 +184,7 @@ def __init__(
self.active = True
self.max_beam_width = max_beam_width
self.max_draft_len = max_draft_len
self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -192,6 +193,10 @@ def __init__(
self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0

self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()
@@ -236,6 +241,7 @@ def __init__(
self.max_batch_size = max_batch_size
self.adp_ctx_waiting_iters_count = 0
self.adp_ctx_batching_wait_iters_count = 0
self.batch_wait_iters_count = 0

# request fetcher initialization
self.executor_request_queue = ExecutorRequestQueue(
@@ -1334,6 +1340,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
balanced_context_requests = context_requests
return balanced_context_requests

def _waiting_requests(self, context_requests: list[LlmRequest],
generation_requests: list[LlmRequest]):
if not self.enable_batch_waiting:
return context_requests

waited_context_requests = []
stop_waiting = False
num_scheduled_ctx_tokens = sum(
len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
for gen_req in generation_requests)
num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens

stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
if stop_waiting:
waited_context_requests = context_requests
self.batch_wait_iters_count = 0
else:
self.batch_wait_iters_count += 1
return waited_context_requests

@nvtx_range("_schedule")
def _schedule(self):
scheduler_output = self.scheduler.schedule_request(
@@ -1344,6 +1371,14 @@ def _schedule(self):
scheduler_output.context_requests,
scheduler_output.generation_requests)

# Only wait when generation requests are also scheduled; with no generation requests, waiting could stall indefinitely (dead waiting)
if not self.enable_attention_dp and self.enable_batch_waiting and len(
scheduler_output.context_requests) > 0 and len(
scheduler_output.generation_requests) > 0:
scheduled_context_requests = self._waiting_requests(
scheduler_output.context_requests,
scheduler_output.generation_requests)

scheduled_requests = ScheduledRequests()
scheduled_requests.context_requests = scheduled_context_requests
scheduled_requests.generation_requests = scheduler_output.generation_requests
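To make the control flow above easier to follow, here is a standalone sketch of the release decision implemented by _waiting_requests, assuming the per-request token counts have already been summed. The helper name and arguments are illustrative, not part of the PR.

def should_release_context_requests(wait_iters: int,
                                    batch_wait_timeout_iters: int,
                                    batch_wait_max_tokens_ratio: float,
                                    max_num_tokens: int,
                                    num_ctx_tokens: int,
                                    num_gen_tokens: int) -> bool:
    """Return True when the accumulated context requests should be scheduled now."""
    num_scheduled_tokens = num_ctx_tokens + num_gen_tokens
    timed_out = wait_iters >= batch_wait_timeout_iters
    enough_tokens = (num_scheduled_tokens
                     >= batch_wait_max_tokens_ratio * max_num_tokens)
    return timed_out or enough_tokens


# Example: a ratio of 0.75 with max_num_tokens=8192 gives a 6144-token threshold.
print(should_release_context_requests(
    wait_iters=3, batch_wait_timeout_iters=10,
    batch_wait_max_tokens_ratio=0.75, max_num_tokens=8192,
    num_ctx_tokens=5000, num_gen_tokens=64))  # False -> keep waiting
print(should_release_context_requests(
    wait_iters=3, batch_wait_timeout_iters=10,
    batch_wait_max_tokens_ratio=0.75, max_num_tokens=8192,
    num_ctx_tokens=7000, num_gen_tokens=64))  # True -> schedule the batch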
33 changes: 32 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -2255,6 +2255,18 @@ class TorchLlmArgs(BaseLlmArgs):
"If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
status="prototype")

batch_wait_timeout_iters: int = Field(
default=0,
description=
"Maximum number of scheduler iterations to wait while accumulating incoming requests to improve GPU utilization. If greater than 0, the scheduler may delay batch processing to gather more requests, up to the specified iteration limit. If 0, iteration-based batching delays are disabled.",
status="prototype")

batch_wait_max_tokens_ratio: float = Field(
default=0,
description=
"Token accumulation threshold ratio for batch scheduling. If greater than 0, the scheduler accumulates requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens, which improves GPU utilization by ensuring adequately sized batches. If 0, token-based batching delays are disabled.",
status="prototype")

torch_compile_config: Optional[TorchCompileConfig] = Field(
default=None, description="Torch compile config.", status="prototype")

@@ -2528,6 +2540,22 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
raise ValueError("batch_wait_timeout_ms must be greater than 0")
return self

@model_validator(mode='after')
def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
if self.batch_wait_timeout_iters < 0:
Reviewer comment (Collaborator): self.batch_wait_timeout_iters < 0: --> self.batch_wait_timeout_iters <= 0

Reviewer comment (Collaborator): I think == 0 should be okay? == 0 means no wait, right?

raise ValueError(
f"batch_wait_timeout_iters must be >= 0, got {self.batch_wait_timeout_iters}"
)
return self

@model_validator(mode='after')
def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
raise ValueError(
f"batch_wait_max_tokens_ratio must be in range [0, 1], got {self.batch_wait_max_tokens_ratio}"
)
return self

def get_executor_config(
self,
_hf_model_dir: Optional[Path] = None,
@@ -2603,7 +2631,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
attention_dp_batching_wait_iters=self.attention_dp_config.
batching_wait_iters if self.attention_dp_config is not None else
AttentionDpConfig.model_fields['batching_wait_iters'].default,
batch_wait_timeout_ms=self.batch_wait_timeout_ms)
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
batch_wait_timeout_iters=self.batch_wait_timeout_iters,
batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
)


def update_llm_args_with_extra_dict(
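Since batch_wait_timeout_iters and batch_wait_max_tokens_ratio are ordinary TorchLlmArgs fields, they can be passed directly to the LLM constructor, as the accuracy test below does. A hedged usage sketch with a placeholder model path:

from tensorrt_llm import LLM

# Placeholder model path; the keyword names come from the TorchLlmArgs fields above.
llm = LLM(
    "/path/to/model",
    batch_wait_timeout_iters=10,       # release the batch after at most 10 waiting iterations
    batch_wait_max_tokens_ratio=0.75,  # ...or once 75% of max_num_tokens is scheduled
)

# Values outside the validated ranges are rejected at construction time, e.g.
# batch_wait_max_tokens_ratio=1.5 raises ValueError per the validator above.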
43 changes: 43 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1546,6 +1546,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
[(False, False, False), (True, True, True)])
@parametrize_with_ids("mtp_nextn", [0, 2])
@parametrize_with_ids(
"batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
(10, 0.75),
(10, 0),
(0, 0.75)])
def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
overlap_scheduler, mtp_nextn,
batch_wait_timeout_iters,
batch_wait_max_tokens_ratio):
moe_backend = "CUTLASS"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph,
capture_num_tokens=[2048, 8192],
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
batch_wait_timeout_iters=batch_wait_timeout_iters,
batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
moe_config=MoeConfig(backend=moe_backend))
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
if fp8kv:
kv_cache_config.dtype = "fp8"
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
kv_cache_config=kv_cache_config,
**pytorch_config,
enable_attention_dp=False,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False, True])
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -512,6 +512,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
8 changes: 8 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -131,6 +131,14 @@ methods:
annotation: float
default: 0
status: prototype
batch_wait_timeout_iters:
Reviewer comment (Collaborator): I noticed that there are separate configurations for both ADP and TP. Is it possible (or does it make sense) to combine them, or is it better to keep the settings separate?

annotation: int
default: 0
status: prototype
batch_wait_max_tokens_ratio:
annotation: float
default: 0
status: prototype
print_iter_log:
annotation: bool
default: False