Commit dbd4a64
[TRTLLM-7963][feat] Cold L2 cache when doing autotune benchmarking.
When running ops in models, the L2 cache is usually cold, but during autotuning we do not clear the L2 cache between benchmark iterations. Some kernels are easily influenced by a warm versus cold L2 cache. To make kernel selection more accurate, this PR clears the L2 cache (via a circular-buffer method) during autotune benchmarking.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
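For readers unfamiliar with the technique, the sketch below illustrates the circular-buffer idea in isolation: enough distinct copies of the inputs are rotated through that consecutive timed calls never read data still resident in L2. This is a minimal standalone sketch, not the autotuner's implementation; `profile_with_cold_l2`, `run_kernel`, and the hard-coded `l2_bytes` (the commit queries the real size from the CUDA driver) are illustrative names and assumptions.

import torch


def profile_with_cold_l2(run_kernel, inputs, repeat=30, l2_bytes=50 * 1024 * 1024):
    """Average kernel time (ms) with an approximately cold L2 before each call."""
    # Size of one full set of input tensors, in bytes.
    one_set_bytes = sum(t.numel() * t.element_size() for t in inputs)
    # Enough rotating copies to cover ~3x the L2 capacity, capped at repeat + 1.
    num_sets = min(l2_bytes * 3 // max(one_set_bytes, 1) + 1, repeat + 1)
    input_sets = [inputs] + [[t.clone() for t in inputs]
                             for _ in range(num_sets - 1)]

    stream = torch.cuda.current_stream()
    # Warm up on the last copy so the first timed call still sees a cold cache.
    for _ in range(3):
        run_kernel(input_sets[-1])
    stream.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream=stream)
    for r in range(repeat):
        run_kernel(input_sets[r % num_sets])  # each call reads a different copy
    end.record(stream=stream)
    stream.synchronize()
    return start.elapsed_time(end) / repeat

The same reasoning explains one detail in the diff below: warm-up runs use input_tensor_batches[-1], so the first timed iteration (which indexes batch 0) starts on a buffer the warm-up never touched.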
1 parent 6b755fd commit dbd4a64

File tree

1 file changed (+81 −6 lines)

tensorrt_llm/_torch/autotuner.py

Lines changed: 81 additions & 6 deletions
@@ -13,6 +13,7 @@
 from typing import Any, Callable, Dict, List, Set, Tuple, Union
 
 import torch
+from cuda.bindings import driver
 
 import tensorrt_llm
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
@@ -99,6 +100,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cold_l2_cache: bool = True
 
 
 @dataclass(unsafe_hash=True)
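The new flag defaults to on. A hedged opt-out sketch (assuming, as the visible fields suggest, that every other TuningConfig field has a default):

from tensorrt_llm._torch.autotuner import TuningConfig

# Keep the previous warm-cache timing for an op whose production inputs
# really do stay resident in L2 between calls.
config = TuningConfig(use_cold_l2_cache=False)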
@@ -524,7 +526,7 @@ class AutoTuner:
     """
     _instance = None
 
-    def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
+    def __init__(self, warmup=3, repeat=30, stream_delay_micro_secs=1000):
        self.repeat = repeat
        self.warmup = warmup
        self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -622,7 +624,7 @@ def choose_one(
                 self.stats.tuned_op_successful_configs[
                     custom_op] = self.stats.tuned_op_successful_configs.get(
                         custom_op, 0) + 1
-                logger.debug(
+                logger.info(
                     f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
                 )
             else:
@@ -685,7 +687,7 @@ def _profile_runners(
             for tac in valid_tactics:
                 try:
                     time_measured = self._profile_single_kernel(
-                        runner, input_tensors, tac, **kwargs)
+                        runner, input_tensors, tac, tuning_config, **kwargs)
                 except Exception as e:
                     # Handle None tensors for optional inputs
                     shapes = self._get_input_sizes(input_tensors)
@@ -727,6 +729,7 @@ def _profile_single_kernel(
         runner: TunableRunner,
         inputs: List[torch.Tensor],
         tactic: Any,
+        tuning_config: TuningConfig,
         **kwargs,
     ) -> float:
         """Profile a single kernel implementation for performance measurement.
@@ -744,10 +747,13 @@ def _profile_single_kernel(
         to get an average execution time. Stream synchronization and delays
         are used to ensure accurate timing.
         """
+        input_tensor_batches = self._prepare_input_tensors_with_batches(
+            inputs, tuning_config)
+
         stream = torch.cuda.current_stream()
         # warm up, no timing
         for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
+            runner(input_tensor_batches[-1], tactic=tactic, **kwargs)
         stream.synchronize()
 
         # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
@@ -758,8 +764,12 @@
         end = torch.cuda.Event(enable_timing=True)
 
         start.record(stream=stream)
-        for _ in range(self.repeat):
-            runner(inputs, tactic=tactic, **kwargs)
+        for r in range(self.repeat):
+            runner(
+                input_tensor_batches[r % len(input_tensor_batches)],
+                tactic=tactic,
+                **kwargs,
+            )
         end.record(stream=stream)
         stream.synchronize()
 
@@ -939,6 +949,39 @@ def _prepare_input_tensors(
             tensors.append(tensor)
         return tensors
 
+    def _prepare_input_tensors_with_batches(
+        self,
+        inputs: List[torch.Tensor],
+        tuning_config: TuningConfig,
+    ) -> List[List[torch.Tensor]]:
+        if not tuning_config.use_cold_l2_cache:
+            return [inputs]
+
+        one_buffer_bytes = sum(
+            input.numel() *
+            input.element_size() if isinstance(input, torch.Tensor) else 0
+            for input in inputs)
+        if one_buffer_bytes <= 0:
+            logger.info(
+                "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
+            )
+            return [inputs]
+
+        num_buffers = self._get_l2_cache_size_in_bytes(
+        ) * 3 // one_buffer_bytes + 1
+        num_buffers = min(num_buffers, self.repeat + 1)
+
+        inputs_list = [inputs]
+        for _ in range(num_buffers - 1):
+            inputs_list.append(
+                list(t.clone() if isinstance(t, torch.Tensor) else t
+                     for t in inputs))
+
+        logger.info(
+            f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling"
+        )
+        return inputs_list
+
     def clear_cache(self) -> None:
         """Clear the profiling cache."""
         self.profiling_cache.clear()
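To put the buffer count in perspective (hypothetical numbers, not taken from the commit): with a 50 MB L2 cache and input tensors totaling 8 MB, num_buffers = 50 MB * 3 // 8 MB + 1 = 19 rotating copies, comfortably under the repeat + 1 = 31 cap, so about 144 MB of extra device memory (18 clones × 8 MB) is held only for the duration of profiling.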
@@ -957,3 +1000,35 @@ def print_profiling_cache(self):
             logger.debug(
                 f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
             )
+
+    def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int:
+        device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
+        return self._checkCudaErrors(
+            driver.cuDeviceGetAttribute(
+                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
+                device,
+            ))
+
+    def _checkCudaErrors(self, result) -> Any:
+        status = result[0]
+        if status != driver.CUresult.CUDA_SUCCESS:
+            code = getattr(status, "value", status)
+            raise RuntimeError(
+                f"CUDA error code={code}({self._cudaGetErrorEnum(status)})")
+        # CUDA APIs always return the status as the first element of the result tuple
+        if len(result) == 1:
+            return None
+        elif len(result) == 2:
+            return result[1]
+        else:
+            return result[1:]
+
+    def _cudaGetErrorEnum(self, error) -> str:
+        from cuda.bindings import nvrtc
+        if isinstance(error, driver.CUresult):
+            err, name = driver.cuGetErrorName(error)
+            return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
+        elif isinstance(error, nvrtc.nvrtcResult):
+            return nvrtc.nvrtcGetErrorString(error)[1]
+        else:
+            raise RuntimeError("Unknown error type: {}".format(error))
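For reference, the raw tuple-unpacking pattern that _checkCudaErrors wraps looks like this when used directly. A standalone sketch, assuming the cuda-python bindings are installed; l2_cache_bytes is a hypothetical helper name:

from cuda.bindings import driver


def l2_cache_bytes(device_id: int = 0) -> int:
    # Each driver call returns (status, *values); unpack and check the status.
    (err,) = driver.cuInit(0)
    assert err == driver.CUresult.CUDA_SUCCESS
    err, device = driver.cuDeviceGet(device_id)
    assert err == driver.CUresult.CUDA_SUCCESS
    err, size = driver.cuDeviceGetAttribute(
        driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, device)
    assert err == driver.CUresult.CUDA_SUCCESS
    return size


print(f"L2 cache: {l2_cache_bytes() / (1 << 20):.1f} MiB")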
