@@ -13,6 +13,7 @@
 from typing import Any, Callable, Dict, List, Set, Tuple, Union
 
 import torch
+from cuda.bindings import driver
 
 import tensorrt_llm
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
@@ -99,6 +100,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cold_l2_cache: bool = True
 
 
 @dataclass(unsafe_hash=True)
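The new `use_cold_l2_cache` flag (default on) makes profiling rotate through several cloned copies of the inputs, so each timed launch reads operands that are unlikely to still sit in L2. Re-timing the same buffers lets them stay cached, which flatters bandwidth-bound kernels. A minimal standalone sketch of the effect — not part of the patch; the tensor size and the `time_it` helper are made up for illustration:

```python
import torch

# One reused tensor stays warm in L2; rotated clones start cold.
x_hot = torch.randn(1 << 22, device="cuda")
x_cold = [x_hot.clone() for _ in range(8)]

def time_it(fn, iters=30):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for i in range(iters):
        fn(i)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per launch

hot = time_it(lambda i: x_hot.mul(2))
cold = time_it(lambda i: x_cold[i % len(x_cold)].mul(2))
print(f"hot-L2: {hot:.4f} ms, cold-L2: {cold:.4f} ms")
```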
@@ -524,7 +526,7 @@ class AutoTuner:
524526 """
525527 _instance = None
526528
527- def __init__ (self , warmup = 3 , repeat = 10 , stream_delay_micro_secs = 1000 ):
529+ def __init__ (self , warmup = 3 , repeat = 30 , stream_delay_micro_secs = 1000 ):
528530 self .repeat = repeat
529531 self .warmup = warmup
530532 self .stream_delay_micro_secs = stream_delay_micro_secs
@@ -622,7 +624,7 @@ def choose_one(
                 self.stats.tuned_op_successful_configs[
                     custom_op] = self.stats.tuned_op_successful_configs.get(
                         custom_op, 0) + 1
-                logger.debug(
+                logger.info(
                     f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
                 )
             else:
@@ -685,7 +687,7 @@ def _profile_runners(
             for tac in valid_tactics:
                 try:
                     time_measured = self._profile_single_kernel(
-                        runner, input_tensors, tac, **kwargs)
+                        runner, input_tensors, tac, tuning_config, **kwargs)
                 except Exception as e:
                     # Handle None tensors for optional inputs
                     shapes = self._get_input_sizes(input_tensors)
@@ -727,6 +729,7 @@ def _profile_single_kernel(
         runner: TunableRunner,
         inputs: List[torch.Tensor],
         tactic: Any,
+        tuning_config: TuningConfig,
         **kwargs,
     ) -> float:
732735 """Profile a single kernel implementation for performance measurement.
@@ -744,10 +747,13 @@ def _profile_single_kernel(
         to get an average execution time. Stream synchronization and delays
         are used to ensure accurate timing.
         """
+        input_tensor_batches = self._prepare_input_tensors_with_batches(inputs, tuning_config)
+
+
         stream = torch.cuda.current_stream()
         # warm up, no timing
         for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
+            runner(input_tensor_batches[-1], tactic=tactic, **kwargs)
         stream.synchronize()
 
         # Delay the profiled kernel launch to eliminate effects of host time overhead in profiling.
@@ -758,8 +764,12 @@ def _profile_single_kernel(
         end = torch.cuda.Event(enable_timing=True)
 
         start.record(stream=stream)
-        for _ in range(self.repeat):
-            runner(inputs, tactic=tactic, **kwargs)
+        for r in range(self.repeat):
+            runner(
+                input_tensor_batches[r % len(input_tensor_batches)],
+                tactic=tactic,
+                **kwargs,
+            )
         end.record(stream=stream)
         stream.synchronize()
 
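Note how warmup touches only `input_tensor_batches[-1]` while the timed loop cycles through all copies round-robin, so the first timed launch on copy 0 still sees a cold cache. The index pattern, spelled out — a hedged illustration, not part of the patch:

```python
# With repeat=30 and, say, 4 input copies, the timed loop visits:
repeat, num_copies = 30, 4
order = [r % num_copies for r in range(repeat)]
print(order[:8])  # [0, 1, 2, 3, 0, 1, 2, 3]
```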
@@ -939,6 +949,39 @@ def _prepare_input_tensors(
             tensors.append(tensor)
         return tensors
 
+    def _prepare_input_tensors_with_batches(
+        self,
+        inputs: List[torch.Tensor],
+        tuning_config: TuningConfig,
+    ) -> List[List[torch.Tensor]]:
+        if not tuning_config.use_cold_l2_cache:
+            return [inputs]
+
+        one_buffer_bytes = sum(
+            t.numel() * t.element_size() if isinstance(t, torch.Tensor)
+            else 0
+            for t in inputs)
+        if one_buffer_bytes <= 0:
+            logger.info(
+                "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
+            )
+            return [inputs]
+
+        num_buffers = (self._get_l2_cache_size_in_bytes() * 3 //
+                       one_buffer_bytes + 1)
+        num_buffers = min(num_buffers, self.repeat + 1)
+
+        inputs_list = [inputs]
+        for _ in range(num_buffers - 1):
+            inputs_list.append(
+                [t.clone() if isinstance(t, torch.Tensor) else t
+                 for t in inputs])
+
+        logger.info(
+            f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}; using {num_buffers} input copies for profiling."
+        )
+        return inputs_list
+
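To make the heuristic above concrete, here is the arithmetic with made-up numbers (a ~50 MiB L2 and 8 MiB of input tensors are assumptions for the example, not measurements): enough copies are allocated to cover three times the L2 size, capped at `repeat + 1`.

```python
l2_bytes = 50 * 1024 * 1024          # hypothetical L2 size
one_buffer_bytes = 8 * 1024 * 1024   # hypothetical total input bytes
repeat = 30

num_buffers = l2_bytes * 3 // one_buffer_bytes + 1   # 150 // 8 + 1 = 19
num_buffers = min(num_buffers, repeat + 1)           # min(19, 31) = 19
```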
     def clear_cache(self) -> None:
         """Clear the profiling cache."""
         self.profiling_cache.clear()
@@ -957,3 +1000,35 @@ def print_profiling_cache(self):
             logger.debug(
                 f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
             )
+
+    def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int:
+        device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
+        return self._checkCudaErrors(
+            driver.cuDeviceGetAttribute(
+                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
+                device,
+            ))
+
+    def _checkCudaErrors(self, result) -> Any:
+        # CUDA driver APIs return the status as the first element of the result tuple.
+        status = result[0]
+        if status != driver.CUresult.CUDA_SUCCESS:
+            code = getattr(status, "value", status)
+            raise RuntimeError(
+                f"CUDA error code={code} ({self._cudaGetErrorEnum(status)})")
+        if len(result) == 1:
+            return None
+        elif len(result) == 2:
+            return result[1]
+        else:
+            return result[1:]
+
+    def _cudaGetErrorEnum(self, error) -> str:
+        from cuda.bindings import nvrtc
+        if isinstance(error, driver.CUresult):
+            err, name = driver.cuGetErrorName(error)
+            return name.decode() if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
+        elif isinstance(error, nvrtc.nvrtcResult):
+            return nvrtc.nvrtcGetErrorString(error)[1].decode()
+        else:
+            raise RuntimeError(f"Unknown error type: {error}")
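For reference, the cuda-python driver-API pattern these helpers rely on can be exercised standalone. A hedged sketch (requires the `cuda-python` package; `cuInit` is only needed when nothing else, e.g. torch, has initialized CUDA yet):

```python
from cuda.bindings import driver

(err,) = driver.cuInit(0)
assert err == driver.CUresult.CUDA_SUCCESS
err, dev = driver.cuDeviceGet(0)
assert err == driver.CUresult.CUDA_SUCCESS
err, l2 = driver.cuDeviceGetAttribute(
    driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, dev)
assert err == driver.CUresult.CUDA_SUCCESS
print(f"L2 cache: {l2 / (1 << 20):.1f} MiB")
```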