3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -170,7 +170,8 @@ def __post_init__(self) -> None:
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
metadata = super().create_cuda_graph_metadata(max_batch_size,
sub_cross_metadata,
max_draft_tokens)
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/attention_backend/interface.py
@@ -137,6 +137,7 @@ class AttentionMetadata:

# This buffer is currently only used for TrtllmAttentionMetadata.
cache_indirection: Optional[torch.Tensor] = None
cuda_graph_buffers: dict[str, list[torch.Tensor]] = None

def __post_init__(self) -> None:
if self.is_cross:
@@ -282,7 +283,8 @@ def prepare(self):
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
"""
Creates metadata for CUDA graph execution.
CUDA graphs require the use of pre-allocated buffers for all tensors in fields.
@@ -294,6 +296,7 @@ def create_cuda_graph_metadata(self,

cuda_graph_metadata = copy.copy(self)
cuda_graph_metadata.is_cuda_graph = True
cuda_graph_metadata.cuda_graph_buffers = buffers
if self.has_cross_sub_metadata:
cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
max_batch_size, True)
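For illustration only, a minimal sketch of how a caller might pass the new `buffers` argument; `attn_metadata` stands for an already-constructed AttentionMetadata instance and the numeric values are assumptions, not taken from this PR:

import torch

# One dictionary owned by the caller; every CUDA-graph metadata object created
# from it draws its per-field CUDA tensors from the same pool.
shared_buffers: dict[str, list[torch.Tensor]] = {}

graph_metadata = attn_metadata.create_cuda_graph_metadata(
    max_batch_size=8,
    sub_cross_metadata=False,
    max_draft_tokens=0,
    buffers=shared_buffers,
)
assert graph_metadata.is_cuda_graph
assert graph_metadata.cuda_graph_buffers is shared_buffers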
85 changes: 70 additions & 15 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -600,21 +600,76 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:

def __post_init__(self) -> None:
super().__post_init__()
self._post_init_with_buffers(self.cuda_graph_buffers)

def _post_init_with_buffers(self, buffers) -> None:

# Set a default value, as max_num_sequences is not always set.
if self.max_num_sequences is None:
self.max_num_sequences = self.max_num_requests

self.prompt_lens_cuda = torch.empty(
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
"""
Finds a compatible, reusable buffer from a cache or creates a new one.

This function searches for a pre-allocated tensor (buffer) that can be
reused for an operation involving a tensor with the shape of `tensor_shape`.

The compatibility rule is: the buffer's total number of elements must be >= that of `tensor_shape`.

If a compatible buffer is found, it is returned immediately. Otherwise, a new
buffer is allocated on the 'cuda' device with the given 'tensor_shape' and 'dtype'.

Args:
tensor_shape: The required shape.
dtype: The required dtype.
cache_name: The key for the specific list of buffers to search in.

Returns:
An existing compatible buffer or a newly created one.
"""
if buffers is not None:
# Safely get the list of candidates. Defaults to an empty list if key is missing.
candidate_buffers = buffers.get(cache_name, [])
numel_like = math.prod(tensor_shape)

for buffer in candidate_buffers:
numel_buffer = buffer.numel()

# buffer just needs to be large enough.
if numel_buffer >= numel_like:
return buffer[0:numel_like].view(
tensor_shape) # Found a fit, return immediately.

# If we get here, no suitable buffer was found in the cache. Create a new one.
new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
if buffers is not None:
buffers.setdefault(cache_name, []).append(new_buffer)
return new_buffer

def get_empty_like(like_tensor: torch.Tensor,
cache_name: str) -> torch.Tensor:
return get_empty(
like_tensor.shape,
cache_name=cache_name,
dtype=like_tensor.dtype,
)

self.prompt_lens_cuda = get_empty(
(self.max_num_sequences, ),
device='cuda',
cache_name="prompt_lens_cuda",
dtype=torch.int,
)
self.prompt_lens_cpu = torch.empty_like(
self.prompt_lens_cuda,
device='cpu',
pin_memory=True,
)
self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda)
self.kv_lens_cuda = get_empty_like(
self.prompt_lens_cuda,
cache_name="kv_lens_cuda",
)
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
device='cpu',
pin_memory=True)
@@ -628,13 +683,13 @@ def __post_init__(self) -> None:
dtype=torch.int8,
)
if self.kv_cache_manager is not None:
self.kv_cache_block_offsets = torch.empty(
self.kv_cache_block_offsets = get_empty(
[
self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_cache_block_offsets",
dtype=torch.int32,
device='cuda',
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,
@@ -644,37 +699,37 @@ def __post_init__(self) -> None:
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
self.block_ids_per_seq = torch.zeros(
self.block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
self.kv_block_ids_per_seq = torch.zeros(
self.kv_block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
if self.enable_paged_context_mla:
# for kv cache reuse/chunked context in MLA
self.ctx_cached_token_indptr = torch.zeros(
self.ctx_cached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(
self.ctx_cached_token_indptr,
device='cpu',
pin_memory=True,
)
self.ctx_uncached_token_indptr = torch.zeros(
self.ctx_uncached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_uncached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_uncached_token_indptr = torch.zeros_like(
@@ -683,9 +738,9 @@ def __post_init__(self) -> None:
pin_memory=True,
)
# context full seqlens include cached tokens and uncached tokens
self.ctx_kv_indptr = torch.zeros(
self.ctx_kv_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_kv_indptr",
dtype=torch.int64,
)
self.host_ctx_kv_indptr = torch.zeros_like(
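For reference, a self-contained approximation of the reuse policy that `get_empty` implements above: return a view of the first cached buffer with enough elements, otherwise allocate a new CUDA tensor and cache it under `cache_name`. This sketch assumes a CUDA device is available and that buffers cached under a given name are 1-D, so the prefix-slice-and-view step is always valid:

import math

import torch

def get_empty(buffers: dict[str, list[torch.Tensor]] | None,
              tensor_shape: list[int], dtype: torch.dtype,
              cache_name: str) -> torch.Tensor:
    if buffers is not None:
        numel_needed = math.prod(tensor_shape)
        for buffer in buffers.get(cache_name, []):
            # Any cached buffer with at least numel_needed elements can back
            # the result as a reshaped prefix view.
            if buffer.numel() >= numel_needed:
                return buffer[:numel_needed].view(tensor_shape)
    # No fit in the cache: allocate on the GPU and remember it for next time.
    new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
    if buffers is not None:
        buffers.setdefault(cache_name, []).append(new_buffer)
    return new_buffer

cache: dict[str, list[torch.Tensor]] = {}
first = get_empty(cache, [64], torch.int32, "kv_lens_cuda")   # allocates
second = get_empty(cache, [32], torch.int32, "kv_lens_cuda")  # reuses storage
assert second.data_ptr() == first.data_ptr()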
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -73,7 +73,8 @@ def __init__(
self.optional_extra_model_inputs = ["mrope_position_deltas"]

def __del__(self):
self._graph.reset()
if self._graph is not None:
self._graph.reset()

def capture(
self,
11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -425,6 +425,7 @@ def __init__(
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_dummy_request = None
self.cuda_graph_meta_buffers: dict[str, list[torch.Tensor]] = {}

# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
@@ -970,15 +971,16 @@ def _maybe_get_cuda_graph(

num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, draft_len)
num_sequences_in_batch, False, draft_len,
self.cuda_graph_meta_buffers)

assert attn_metadata.is_cuda_graph

spec_metadata = None
if self.enable_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
else:
spec_metadata = None

# Initialize nested dictionary if needed
if batch_size not in self._cuda_graphs:
@@ -1143,9 +1145,10 @@ def _release_cuda_graphs(self):
for draft_len, graph in draft_graphs.items():
del graph
self._cuda_graphs.clear()
torch.cuda.empty_cache()
del self._cuda_graph_mem_pool
self._cuda_graph_mem_pool = None
self.cuda_graph_meta_buffers.clear()
torch.cuda.empty_cache()

def get_max_num_sequences(self) -> int:
"""
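Taken together with the attention-backend changes, the shared buffer dictionary has a simple lifecycle: it is created once in the engine's __init__, passed to every create_cuda_graph_metadata call so metadata for different batch sizes reuses the same CUDA tensors, and cleared in _release_cuda_graphs before torch.cuda.empty_cache() so the freed memory can actually be returned. A condensed, hypothetical sketch of that flow (attribute names follow the diff; the engine class itself is a stand-in):

import torch

class EngineSketch:
    """Stand-in for the model engine; only the buffer plumbing is shown."""

    def __init__(self, attn_metadata) -> None:
        self.attn_metadata = attn_metadata
        # Keys are cache names (e.g. "kv_lens_cuda"); values are the
        # pre-allocated CUDA tensors reused across graph captures.
        self.cuda_graph_meta_buffers: dict[str, list[torch.Tensor]] = {}

    def metadata_for_batch(self, num_sequences: int, draft_len: int = 0):
        # Every captured batch size shares the same buffer pool.
        return self.attn_metadata.create_cuda_graph_metadata(
            num_sequences, False, draft_len, self.cuda_graph_meta_buffers)

    def release(self) -> None:
        # Drop the cached tensors first, then ask the allocator to release
        # the now-unreferenced blocks.
        self.cuda_graph_meta_buffers.clear()
        torch.cuda.empty_cache()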
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -278,7 +278,6 @@ full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5410391)
accuracy/test_llm_api.py::TestMistral_Nemo_12B_Base::test_fp8 SKIP (https://nvbugs/5413197)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)