94 changes: 89 additions & 5 deletions tensorrt_llm/_torch/autotuner.py
@@ -13,6 +13,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from cuda.bindings import driver

import tensorrt_llm
from tensorrt_llm.bindings.internal.runtime import delay_kernel
@@ -94,11 +95,15 @@ class TuningConfig:
If not provided, the autotuner will not consider the max num tokens.
inputs_pre_hook (Callable): A function that takes a list of input tensors, returns a list of modified input tensors.
It is called before the input tensors are prepared for the tuning process to match the real data distribution.
use_cold_l2_cache (bool): Whether to profile with a cold L2 cache.
When enabled, a circular buffer of cloned input tensors is created so that consecutive profiling iterations do not hit the L2 cache.
Note that not all tuning processes benefit from this feature.
"""
dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
constraint_specs: Tuple[ConstraintSpec, ...] = ()
tune_max_num_tokens: int = None
inputs_pre_hook: Callable = None
use_cold_l2_cache: bool = False
use_cuda_graph: bool = True


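A minimal usage sketch of the new flag, with placeholder spec values that are not taken from this PR: a tuning config that clones inputs into a circular buffer so each profiling repeat runs against a cold L2 cache.

    # Hypothetical usage sketch; spec values are placeholders, not from this PR.
    config = TuningConfig(
        dynamic_tensor_specs=(),
        constraint_specs=(),
        use_cold_l2_cache=True,  # cycle through cloned input buffers during profiling
        use_cuda_graph=True,
    )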
@@ -816,6 +821,7 @@ def _profile_runners(
runner=runner,
inputs=input_tensors,
tactic=tac,
tuning_config=tuning_config,
use_cuda_graph=tuning_config.use_cuda_graph,
**kwargs,
)
@@ -864,6 +870,7 @@ def _profile_single_kernel(
runner: TunableRunner,
inputs: List[torch.Tensor],
tactic: Any,
tuning_config: TuningConfig,
use_cuda_graph: bool = False,
**kwargs,
) -> float:
@@ -882,6 +889,9 @@
to get an average execution time. Stream synchronization and delays
are used to ensure accurate timing.
"""
input_tensor_batches = self._prepare_input_tensors_with_batches(
inputs, tuning_config)

stream = torch.cuda.current_stream()
# If the warm up time is longer than 0.5ms, we will profile the kernel with fewer repeats.
profile_fewer_repeat = 2
@@ -897,8 +907,13 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
with torch.cuda.stream(stream):
if use_cuda_graph:
with torch.cuda.graph(graph):
for _ in range(repeat):
runner(inputs, tactic=tactic, **kwargs)
for r in range(repeat):
runner(
input_tensor_batches[r %
len(input_tensor_batches)],
tactic=tactic,
**kwargs,
)

stream.synchronize()

@@ -915,16 +930,20 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
if use_cuda_graph:
graph.replay()
else:
for _ in range(repeat):
runner(inputs, tactic=tactic, **kwargs)
for r in range(repeat):
runner(
input_tensor_batches[r % len(input_tensor_batches)],
tactic=tactic,
**kwargs,
)

end.record()
stream.synchronize()

return start.elapsed_time(end) / repeat

for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)
runner(input_tensor_batches[-1], tactic=tactic, **kwargs)

fewer_repeat_avg_time = pure_profile(stream, profile_fewer_repeat)

@@ -1127,6 +1146,39 @@ def _prepare_input_tensors(
tensors.append(tensor)
return tensors

def _prepare_input_tensors_with_batches(
self,
inputs: List[torch.Tensor],
tuning_config: TuningConfig,
) -> List[List[torch.Tensor]]:
if not tuning_config.use_cold_l2_cache:
return [inputs]

one_buffer_bytes = sum(
input.numel() *
input.element_size() if isinstance(input, torch.Tensor) else 0
for input in inputs)
if one_buffer_bytes <= 0:
logger.debug(
"[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
)
return [inputs]

num_buffers = self._get_l2_cache_size_in_bytes(
) * 3 // one_buffer_bytes + 1
num_buffers = min(num_buffers, self.repeat + 1)

inputs_list = [inputs]
for _ in range(num_buffers - 1):
inputs_list.append(
list(t.clone() if isinstance(t, torch.Tensor) else t
for t in inputs))

logger.debug(
    f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, using {num_buffers} input buffers for profiling"
)
return inputs_list

def clear_cache(self) -> None:
"""Clear the profiling cache."""
self.profiling_cache.clear()
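The buffer count in _prepare_input_tensors_with_batches follows a simple rule: create enough copies that cycling through all of them touches roughly three times the L2 cache, capped at repeat + 1. A worked example with assumed sizes (not measured values from this PR):

    # Worked example of the sizing rule, using assumed sizes rather than real
    # device or profiling values.
    l2_bytes = 50 * 1024 * 1024          # hypothetical ~50 MB L2 cache
    one_buffer_bytes = 8 * 1024 * 1024   # total bytes of one set of input tensors
    repeat = 100                         # corresponds to self.repeat
    num_buffers = l2_bytes * 3 // one_buffer_bytes + 1  # -> 19 buffers
    num_buffers = min(num_buffers, repeat + 1)          # still 19, since 19 <= 101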
@@ -1233,3 +1285,35 @@ def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]):
tactics_capture._replay_context_idx = 0
# Restore previous active capture state
self._active_capture = prev_active

def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int:
device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
return self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
device,
))

def _checkCudaErrors(self, result) -> Any:
status = result[0]
if status != driver.CUresult.CUDA_SUCCESS:
code = getattr(status, "value", status)
raise RuntimeError(
f"CUDA error code={code}({self._cudaGetErrorEnum(status)})")
# CUDA APIs always return the status as the first element of the result tuple
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]

def _cudaGetErrorEnum(self, error) -> str:
from cuda.bindings import nvrtc
if isinstance(error, driver.CUresult):
err, name = driver.cuGetErrorName(error)
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))
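For reference, the same L2-size query can be exercised standalone through the cuda-python driver bindings; this sketch assumes the cuda-python package is installed and follows the (status, value) tuple convention that _checkCudaErrors unpacks above, with error handling simplified to asserts.

    # Standalone sketch of the L2 cache size query via cuda-python.
    from cuda.bindings import driver

    (err,) = driver.cuInit(0)
    assert err == driver.CUresult.CUDA_SUCCESS
    err, dev = driver.cuDeviceGet(0)
    assert err == driver.CUresult.CUDA_SUCCESS
    err, l2_size = driver.cuDeviceGetAttribute(
        driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, dev)
    assert err == driver.CUresult.CUDA_SUCCESS
    print(f"L2 cache size: {l2_size} bytes")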
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -40,6 +40,7 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
use_cold_l2_cache=True,
)

def __init__(self, alpha: float, output_dtype: torch.dtype):