diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 8251ca8ed10..02e1acca18b 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -13,6 +13,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
+from cuda.bindings import driver
 
 import tensorrt_llm
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
@@ -94,11 +95,15 @@ class TuningConfig:
             If not provided, the autotuner will not consider the max num tokens.
         inputs_pre_hook (Callable): A function that takes a list of input tensors, returns a list of modified input tensors.
             It is called before the input tensors are prepared for the tuning process to match the real data distribution.
+        use_cold_l2_cache (bool): Whether to simulate a cold L2 cache during profiling.
+            When enabled, the autotuner rotates through a circular buffer of cloned input tensors so that repeated kernel launches do not benefit from L2 cache hits.
+            Note that not all tuning processes benefit from this feature.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cold_l2_cache: bool = False
     use_cuda_graph: bool = True
 
 
@@ -816,6 +821,7 @@ def _profile_runners(
                         runner=runner,
                         inputs=input_tensors,
                         tactic=tac,
+                        tuning_config=tuning_config,
                         use_cuda_graph=tuning_config.use_cuda_graph,
                         **kwargs,
                     )
@@ -864,6 +870,7 @@ def _profile_single_kernel(
         runner: TunableRunner,
         inputs: List[torch.Tensor],
         tactic: Any,
+        tuning_config: TuningConfig,
         use_cuda_graph: bool = False,
         **kwargs,
     ) -> float:
@@ -882,6 +889,9 @@ def _profile_single_kernel(
         to get an average execution time. Stream synchronization and delays
         are used to ensure accurate timing.
         """
+        input_tensor_batches = self._prepare_input_tensors_with_batches(
+            inputs, tuning_config)
+
         stream = torch.cuda.current_stream()
         # If the warm up time is longer than 0.5ms, we will profile the kernel with fewer repeats.
         profile_fewer_repeat = 2
@@ -897,8 +907,13 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
             with torch.cuda.stream(stream):
                 if use_cuda_graph:
                     with torch.cuda.graph(graph):
-                        for _ in range(repeat):
-                            runner(inputs, tactic=tactic, **kwargs)
+                        for r in range(repeat):
+                            runner(
+                                input_tensor_batches[r %
+                                                     len(input_tensor_batches)],
+                                tactic=tactic,
+                                **kwargs,
+                            )
                 stream.synchronize()
 
@@ -915,8 +930,12 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
                 if use_cuda_graph:
                     graph.replay()
                 else:
-                    for _ in range(repeat):
-                        runner(inputs, tactic=tactic, **kwargs)
+                    for r in range(repeat):
+                        runner(
+                            input_tensor_batches[r % len(input_tensor_batches)],
+                            tactic=tactic,
+                            **kwargs,
+                        )
                 end.record()
             stream.synchronize()
 
@@ -924,7 +943,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
             return start.elapsed_time(end) / repeat
 
         for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
+            runner(input_tensor_batches[-1], tactic=tactic, **kwargs)
 
         fewer_repeat_avg_time = pure_profile(stream, profile_fewer_repeat)
@@ -1127,6 +1146,39 @@ def _prepare_input_tensors(
             tensors.append(tensor)
         return tensors
 
+    def _prepare_input_tensors_with_batches(
+        self,
+        inputs: List[torch.Tensor],
+        tuning_config: TuningConfig,
+    ) -> List[List[torch.Tensor]]:
+        if not tuning_config.use_cold_l2_cache:
+            return [inputs]
+
+        one_buffer_bytes = sum(
+            input.numel() *
+            input.element_size() if isinstance(input, torch.Tensor) else 0
+            for input in inputs)
+        if one_buffer_bytes <= 0:
+            logger.debug(
+                "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
+            )
+            return [inputs]
+
+        num_buffers = self._get_l2_cache_size_in_bytes(
+        ) * 3 // one_buffer_bytes + 1
+        num_buffers = min(num_buffers, self.repeat + 1)
+
+        inputs_list = [inputs]
+        for _ in range(num_buffers - 1):
+            inputs_list.append(
+                list(t.clone() if isinstance(t, torch.Tensor) else t
+                     for t in inputs))
+
+        logger.debug(
+            f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling"
+        )
+        return inputs_list
+
     def clear_cache(self) -> None:
         """Clear the profiling cache."""
         self.profiling_cache.clear()
@@ -1233,3 +1285,35 @@ def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]):
             tactics_capture._replay_context_idx = 0
         # Restore previous active capture state
         self._active_capture = prev_active
+
+    def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int:
+        device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
+        return self._checkCudaErrors(
+            driver.cuDeviceGetAttribute(
+                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
+                device,
+            ))
+
+    def _checkCudaErrors(self, result) -> Any:
+        status = result[0]
+        if status != driver.CUresult.CUDA_SUCCESS:
+            code = getattr(status, "value", status)
+            raise RuntimeError(
+                f"CUDA error code={code}({self._cudaGetErrorEnum(status)})")
+        # CUDA APIs always return the status as the first element of the result tuple
+        if len(result) == 1:
+            return None
+        elif len(result) == 2:
+            return result[1]
+        else:
+            return result[1:]
+
+    def _cudaGetErrorEnum(self, error) -> str:
+        from cuda.bindings import nvrtc
+        if isinstance(error, driver.CUresult):
+            err, name = driver.cuGetErrorName(error)
+            return name if err == driver.CUresult.CUDA_SUCCESS else ""
+        elif isinstance(error, nvrtc.nvrtcResult):
+            return nvrtc.nvrtcGetErrorString(error)[1]
+        else:
+            raise RuntimeError("Unknown error type: {}".format(error))
diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
index 4babe458e0c..eb67b667b23 100644
--- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -40,6 +40,7 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
+        use_cold_l2_cache=True,
     )
 
     def __init__(self, alpha: float, output_dtype: torch.dtype):
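For reference, the cold-L2 rotation that this patch wires into `_profile_single_kernel` can be illustrated outside the autotuner. The sketch below is illustrative only and not part of the TensorRT-LLM API: `time_with_cold_l2`, `_l2_cache_size_bytes`, and `kernel_fn` are hypothetical names, and it assumes a working `cuda-python` installation and an already-usable CUDA device. Like the patch, it sizes a circular buffer of cloned inputs to roughly 3x the device L2 cache and rotates through the copies while timing.

```python
import torch
from cuda.bindings import driver  # same driver bindings the patch imports


def _l2_cache_size_bytes(device_id: int = 0) -> int:
    # cuda-python driver calls return a (CUresult, value) tuple.
    driver.cuInit(0)  # idempotent; harmless if CUDA is already initialized
    err, dev = driver.cuDeviceGet(device_id)
    assert err == driver.CUresult.CUDA_SUCCESS
    err, size = driver.cuDeviceGetAttribute(
        driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, dev)
    assert err == driver.CUresult.CUDA_SUCCESS
    return size


def time_with_cold_l2(kernel_fn, inputs, repeat: int = 10) -> float:
    """Average kernel_fn latency in ms, rotating inputs to defeat L2 reuse."""
    bytes_per_call = sum(t.numel() * t.element_size() for t in inputs
                         if isinstance(t, torch.Tensor))
    # Enough copies to cover ~3x the L2 cache, but never more than `repeat`.
    num_buffers = min(_l2_cache_size_bytes() * 3 // max(bytes_per_call, 1) + 1,
                      repeat)
    batches = [inputs] + [[
        t.clone() if isinstance(t, torch.Tensor) else t for t in inputs
    ] for _ in range(num_buffers - 1)]

    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warm up on the last copy so the first timed launch (buffer 0) starts cold.
    kernel_fn(batches[-1])
    stream.synchronize()

    start.record(stream)
    for r in range(repeat):
        kernel_fn(batches[r % num_buffers])  # each launch reads a different copy
    end.record(stream)
    stream.synchronize()
    return start.elapsed_time(end) / repeat
```

The patch applies the same bound inside `_prepare_input_tensors_with_batches` (`num_buffers = l2_bytes * 3 // one_buffer_bytes + 1`, capped by `self.repeat + 1`), and its warm-up likewise runs on `input_tensor_batches[-1]` so the first timed iteration does not rehit data left in L2 by the warm-up.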