diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 89d866ee9a1..da4df91f693 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -25,8 +25,8 @@ class DynamicTensorSpec:
     """
     input_idx: int
     dim_idx: int
-    gen_tuning_buckets: Union[Tuple[int], Callable]
-    map_to_tuning_buckets: Callable
+    gen_tuning_buckets: Union[Tuple[int], Callable] = ()
+    map_to_tuning_buckets: Callable = lambda x: x


 @dataclass(slots=True, unsafe_hash=True)
@@ -43,7 +43,7 @@ class ConstraintSpec:
     infer_shape: Callable


-@dataclass(kw_only=True, unsafe_hash=True)
+@dataclass(kw_only=True)
 class TuningConfig:
     """Configuration for autotuning.

@@ -81,9 +81,15 @@ class TuningConfig:
         ...     ),
         ... )
         ... )
+        tune_max_num_tokens (int): The upper bound of the token dimension used during tuning.
+            When set, the tuning buckets are generated up to this value instead of from the
+            current input size, and the token dimension observed at inference time is clamped
+            to it before the cache lookup.
+            If not provided, the autotuner does not constrain the number of tokens.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
+    tune_max_num_tokens: int = None


 @dataclass(unsafe_hash=True)
@@ -139,12 +145,13 @@ class TunableRunner(ABC):
     @abstractmethod
     def get_valid_tactics(self, inputs: List[torch.Tensor],
-                          profile: OptimizationProfile) -> List[int]:
+                          profile: OptimizationProfile, **kwargs) -> List[Any]:
         """One tactic corresponding to one cuda kernel normally, but how to interpret the meaning of tactic
         is pure internal details of the runner.
-        The autotuner will just pass the tactic value to the forward w/o any knowledge on what the tactic
-        means.
+        The autotuner will just pass the tactic value to the forward w/o any knowledge of what the tactic
+        means. Users can implement their own tactic types for flexibility, such as using a dict
+        to represent a collection of named configs.

         tactic==-1 has special meaning, means the fallback kernel which should be able to implement any shapes
         This fallback tactic is needed for 2 reasons:
@@ -166,15 +173,17 @@ def forward(
             /,  # tensors are position only
             inputs: List[torch.Tensor],
             *,  # all others are keyword args only
-            tactic: int = -1,
-            do_preparation: bool = False) -> Any:
+            tactic: Any = -1,
+            do_preparation: bool = False,
+            **kwargs) -> Any:
         """Forward pass for tunable runners.

         Args:
             inputs: List of input tensors (position-only argument)
-            tactic: Integer ID specifying which implementation tactic to use.
-                -1 (default) represents the fallback tactic that must be implemented
-                to handle any input shapes when autotuning is disabled.
+            tactic: An arbitrary value that identifies a specific kernel config.
+                For instance, it can be an integer ID specifying which implementation tactic to use.
+                -1 (default) represents the fallback tactic that must be implemented
+                to handle any input shapes when autotuning is disabled.
             do_preparation: When True, allows one-time setup operations to be
                 performed before tactic evaluation begins. These operations are
                 excluded from the performance measurements during autotuning. Notice that
@@ -182,7 +191,7 @@ def forward(
                 and can be accessed by the following forward calls.

         Returns:
-            Any: Output of the forward pass
+            Any: Output of the forward pass.
""" raise NotImplementedError @@ -277,6 +286,7 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): self.warmup = warmup self.stream_delay_micro_secs = stream_delay_micro_secs self.profiling_cache = {} + self.registered_tuning_configs = {} self.is_tuning_mode = False # Add statistics tracking @@ -296,7 +306,7 @@ def search_cache( runners: List[TunableRunner], input_shapes: Tuple[torch.Size], tuning_config: TuningConfig, - ) -> Tuple[bool, int, int, OptimizationProfile]: + ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]: """Search for cached profiling results matching the current configuration. Args: @@ -316,9 +326,14 @@ def search_cache( return False, 0, -1, None - def choose_one(self, custom_op: str, runners: List[TunableRunner], - tuning_config: TuningConfig, inputs: List[torch.Tensor], - **kwargs) -> Tuple[TunableRunner, int]: + def choose_one( + self, + custom_op: str, + runners: List[TunableRunner], + tuning_config: TuningConfig, + inputs: List[torch.Tensor], + **kwargs, + ) -> Tuple: """Choose the best runner and tactic combination through performance profiling. Args: @@ -329,9 +344,10 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner], **kwargs: Arbitrary keyword arguments, will be passed to get_valid_tactics and forward method of each runner Returns: - Tuple[TunableRunner, int]: A tuple containing: + Tuple: A tuple containing: - The selected runner implementation - The best tactic ID for that runner (-1 if using fallback) + - The best config for that runner (if configs is not empty) Note: The method profiles different implementations and tactics to find the @@ -342,26 +358,29 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner], """ input_shapes = tuple(self._get_input_sizes(inputs)) - # Early return if it's not tuning, use cache found one or fallback one if not self.is_tuning_mode: - is_cache_hit, runner_id, tactic, stored_profile = self.search_cache( + is_cache_hit, best_runner_id, best_tactic, stored_profile = self.search_cache( custom_op, runners, input_shapes, tuning_config) - runner = runners[runner_id] + best_runner = runners[best_runner_id] # TODO: check the stored runner and tactic can implement this shape here # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf. - if not is_cache_hit and len(self.profiling_cache) > 0: - # Only log once for each custom op and only when cache is not empty + + # Record the cache miss config. + # Expect no cache miss in inference. Thus, any cache miss should be recorded. 
+            if not is_cache_hit:
                 logger.warning_once(
                     f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
                     key=(custom_op))
-            return runner, tactic
+
+            return (best_runner, best_tactic)

         assert len(runners) > 0, "At least one runner is required"
         assert all([isinstance(r, TunableRunner) for r in runners]), \
             "All Given runners must be subclass of TunableRunner"

         profiles = self._optimization_profiles(tuning_config, inputs)
+        # Record the total number of configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)

         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
-            is_cache_hit, runner_id, tactic, _ = self.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+            is_cache_hit, *_ = self.search_cache(custom_op, runners,
+                                                 p.get_opt_shapes(),
+                                                 tuning_config)
             if not is_cache_hit:
-                min_time = float('inf')
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
-                runner_id, tactic = None, None
-                for r_id, r in enumerate(runners):
-                    # TODO: use FakeTensor here.
-                    valid_tactics = r.get_valid_tactics(tensors, p)
-                    runner_arg_names = {
-                        p.name
-                        for p in inspect.signature(
-                            r.forward).parameters.values()
-                    }
-                    if "do_preparation" in runner_arg_names and len(
-                            valid_tactics) > 0:
-                        r(tensors, tactic=-1, do_preparation=True, **kwargs)
-                    for tac in valid_tactics:
-                        try:
-                            time_measured = self._profile_single_kernel(
-                                r, tensors, tac, **kwargs)
-                        except Exception as e:
-                            shapes = self._get_input_sizes(tensors)
-
-                            logger.warning(
-                                f"[Autotuner] Failed when profiling runner={r}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details."
-                            )
-                            logger.debug(f"[Autotuner] Exception captured: {e}")
-
-                            # Record the failed profiling combinations
-                            new_tuning_failure_occured = True
-                            if custom_op not in self.stats.failed_profiling_count:
-                                self.stats.failed_profiling_count[
-                                    custom_op] = set()
-                            self.stats.failed_profiling_count[custom_op].add(
-                                AutoTuner._get_cache_key(
-                                    custom_op, r, p.get_opt_shapes(),
-                                    tuning_config))
-
-                            # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
-                            # or some runtime error occurs during profiling.
-                            time_measured = float('inf')
-                        if time_measured < min_time:
-                            min_time = time_measured
-                            runner_id, tactic = r_id, tac
-                if runner_id is not None:
+                best_runner_id, best_tactic, has_tuning_failure_occured = self._profile_runners(
+                    custom_op, runners, tensors, p, tuning_config, **kwargs)
+                if best_runner_id is not None:
                     # At least one valid (runner, tactic) pair is found
                     cache_key = AutoTuner._get_cache_key(
-                        custom_op, runners[runner_id], p.get_opt_shapes(),
+                        custom_op, runners[best_runner_id], p.get_opt_shapes(),
                         tuning_config)  # inspect call stack
-                    self.profiling_cache[cache_key] = (runner_id, tactic, p)
+                    self.profiling_cache[cache_key] = (best_runner_id,
+                                                       best_tactic, p)
                     self.stats.tuned_op_successful_configs[
                         custom_op] = self.stats.tuned_op_successful_configs.get(
                             custom_op, 0) + 1
                     logger.debug(
-                        f"[Autotuner] Profiling runner={runners[runner_id]}, tactic={tactic} for cache_key={cache_key}."
+                        f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
                     )
                 else:
                     logger.warning(
@@ -434,6 +416,7 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                         f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
                         f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash."
                     )
+                new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured

         # If failed profiling tactics occurs, log the error.
         if new_tuning_failure_occured:
@@ -450,7 +433,64 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
         _, runner_id, tactic, _ = self.search_cache(custom_op, runners,
                                                     input_shapes,
                                                     tuning_config)

-        return runners[runner_id], tactic
+        return (runners[runner_id], tactic)
+
+    def _profile_runners(
+        self,
+        custom_op: str,
+        runners: List[TunableRunner],
+        input_tensors: List[torch.Tensor],
+        profile: OptimizationProfile,
+        tuning_config: TuningConfig,
+        **kwargs,
+    ) -> Tuple:
+        min_time = float('inf')
+        has_tuning_failure_occured = False
+        best_runner_id, best_tactic = None, None
+        for runner_id, runner in enumerate(runners):
+            # TODO: use FakeTensor here.
+            runner_arg_names = {
+                p.name
+                for p in inspect.signature(runner.forward).parameters.values()
+            }
+            valid_tactics = runner.get_valid_tactics(input_tensors, profile)
+            if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
+                runner(
+                    input_tensors,
+                    tactic=-1,
+                    do_preparation=True,
+                    **kwargs,
+                )
+
+            for tac in valid_tactics:
+                try:
+                    time_measured = self._profile_single_kernel(
+                        runner, input_tensors, tac, **kwargs)
+                except Exception as e:
+                    # Handle None tensors for optional inputs
+                    shapes = self._get_input_sizes(input_tensors)
+                    logger.warning(
+                        f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details."
+                    )
+                    logger.debug(f"[Autotuner] Exception captured: {e}")
+
+                    # Record the failed profiling combinations
+                    if custom_op not in self.stats.failed_profiling_count:
+                        self.stats.failed_profiling_count[custom_op] = set()
+                    self.stats.failed_profiling_count[custom_op].add(
+                        AutoTuner._get_cache_key(custom_op, runner,
+                                                 profile.get_opt_shapes(),
+                                                 tuning_config))

+                    # Set time_measured to inf to mark the tactic as failed. This can happen when
+                    # `get_valid_tactics` mistakenly returns invalid tactics, or when a runtime
+                    # error occurs during profiling.
+                    time_measured = float('inf')
+                    has_tuning_failure_occured = True
+                if time_measured < min_time:
+                    min_time = time_measured
+                    best_runner_id, best_tactic = runner_id, tac
+
+        return best_runner_id, best_tactic, has_tuning_failure_occured

     def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:

@@ -462,15 +502,19 @@ def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:

         return sizes

-    def _profile_single_kernel(self, runner: TunableRunner,
-                               inputs: List[torch.Tensor], tactic: int,
-                               **kwargs) -> float:
+    def _profile_single_kernel(
+        self,
+        runner: TunableRunner,
+        inputs: List[torch.Tensor],
+        tactic: Any,
+        **kwargs,
+    ) -> float:
         """Profile a single kernel implementation for performance measurement.
         Args:
             runner (TunableRunner): The runner implementation to profile
             inputs (List[torch.Tensor]): Input tensors for the kernel
-            tactic (int): Tactic ID to use for this profiling run
+            tactic (Any): Tactic to use for this profiling run

         Returns:
             Average execution time in milliseconds
@@ -503,7 +547,7 @@ def _profile_single_kernel(self, runner: TunableRunner,

         shapes = self._get_input_sizes(inputs)
         logger.debug(
-            f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time}ms."
+            f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
         )

         return avg_time
@@ -541,10 +585,23 @@ def _optimization_profiles(
             assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance(spec.gen_tuning_buckets, (list, tuple)), \
                 "The given dynamic dimension must provide a opt value generation function or a list of opt values"
             if inspect.isfunction(spec.gen_tuning_buckets):
-                opt_shapes = spec.gen_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx].val)
+                if tuning_config.tune_max_num_tokens is None:
+                    # Generate the opt values from the current input size
+                    opt_shapes = spec.gen_tuning_buckets(
+                        base_profile.shapes[spec.input_idx][spec.dim_idx].val)
+                else:
+                    # Generate the opt values up to tune_max_num_tokens
+                    opt_shapes = spec.gen_tuning_buckets(
+                        tuning_config.tune_max_num_tokens)
             else:
+                # The default value is an empty tuple, meaning the user does not want to tune this dimension.
                 opt_shapes = spec.gen_tuning_buckets
+            # Add the current input value as one of the opt values
+            opt_shapes = set(opt_shapes)
+            opt_shapes.add(
+                spec.map_to_tuning_buckets(
+                    base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+            opt_shapes = sorted(list(opt_shapes))
             opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), )
             opt_shapes_max = {
                 v1: v2
@@ -570,6 +627,8 @@ def _optimization_profiles(
         for spec in tuning_config.constraint_specs:
             min_value = opt_value = max_value = spec.infer_shape(
                 p.get_opt_shapes())
+            if p.shapes[spec.input_idx] == [StaticDim(0)]:
+                continue
             p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
                 min_value, opt_value, max_value)

         generated_profiles.append(p)
@@ -578,8 +637,13 @@ def _optimization_profiles(

     @classmethod
     @lru_cache(maxsize=None)
-    def _find_nearest_profile(cls, shapes: Tuple[torch.Size],
-                              tuning_config: TuningConfig) -> Tuple:
+    def _find_nearest_profile(
+        cls,
+        shapes: Tuple[torch.Size],
+        dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
+        constraint_specs: Tuple[ConstraintSpec, ...],
+        tune_max_num_tokens: int = None,
+    ) -> Tuple:
         """Find the nearest optimization profile for given inputs

         User can define their own nearest profile generation method to reduce the host overhead.
@@ -594,13 +658,20 @@ def _find_nearest_profile(cls, shapes: Tuple[torch.Size],
         """
         base_profile = list(list(shape) for shape in shapes)
-        for spec in tuning_config.dynamic_tensor_specs:
+        for spec in dynamic_tensor_specs:
             base_profile[spec.input_idx][
                 spec.dim_idx] = spec.map_to_tuning_buckets(
                     base_profile[spec.input_idx][spec.dim_idx])
+            if tune_max_num_tokens is not None:
+                base_profile[spec.input_idx][spec.dim_idx] = min(
+                    base_profile[spec.input_idx][spec.dim_idx],
+                    tune_max_num_tokens)
+
         # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
-        for spec in tuning_config.constraint_specs:
+        for spec in constraint_specs:
+            if base_profile[spec.input_idx] == [0]:
+                continue
             base_profile[spec.input_idx][spec.dim_idx] = -1

         return tuple(tuple(shape) for shape in base_profile)
@@ -614,7 +685,10 @@ def _get_cache_key(
         tuning_config: TuningConfig,
     ) -> Tuple:
         return (custom_op, runner.__class__.__name__, hash(runner),
-                cls._find_nearest_profile(input_shapes, tuning_config))
+                cls._find_nearest_profile(input_shapes,
+                                          tuning_config.dynamic_tensor_specs,
+                                          tuning_config.constraint_specs,
+                                          tuning_config.tune_max_num_tokens))

     def _create_tensor_like(self, origin_tensor: torch.Tensor,
                             dims: List[Dim]) -> torch.Tensor:
@@ -672,5 +746,6 @@ def print_profiling_cache(self):
             f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
         )
         for key, value in self.profiling_cache.items():
             runner_id, tactic, _ = value
-            logger.debug(f"[Autotuner] {key}: ({runner_id}, {tactic})")
+            logger.debug(
+                f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})")
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index 0ca269ad157..c8c103557f6 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -22,9 +22,12 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
-    tuning_config = TuningConfig(dynamic_tensor_specs=(
-        DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-                          lambda x: min(last_positive_power_of_2(x), 8192)), ))
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
+            lambda x: min(last_positive_power_of_2(x), 8192)), ),
+        tune_max_num_tokens=8192,
+    )

     def __init__(
         self,
@@ -109,15 +112,6 @@ def forward(
             do_preparation,
         )

-    @classmethod
-    @lru_cache(maxsize=None)
-    def refine_tuning_config(cls, tune_max_num_tokens: int):
-        cls.tuning_config = TuningConfig(
-            dynamic_tensor_specs=(DynamicTensorSpec(
-                0, 0, get_last_power_of_2_num_tokens_buckets(
-                    tune_max_num_tokens), lambda x: min(
-                        last_positive_power_of_2(x), tune_max_num_tokens)), ))
-

 @torch.library.custom_op("trtllm::fused_moe", mutates_args=())
 def fused_moe(
@@ -153,7 +147,6 @@ def fused_moe(
 ) -> List[torch.Tensor]:

     tuner = AutoTuner.get()
-    MoERunner.refine_tuning_config(tune_max_num_tokens)

     # Only the non-alltoall case is considered for profiling in the warmup phase.
     # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall.
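Aside: the interaction between `tune_max_num_tokens=8192` and the power-of-two bucket helpers above can be seen in a small, self-contained sketch. The two functions below are illustrative stand-ins written for this note, not the actual `tensorrt_llm` implementations of `get_last_power_of_2_num_tokens_buckets` and `last_positive_power_of_2`, so the exact values are an assumption:

```python
# Illustrative stand-ins for the helpers used by MoERunner.tuning_config;
# the real implementations ship with tensorrt_llm and may differ in detail.
def last_positive_power_of_2(x: int) -> int:
    # Largest power of two that is <= x (assumes x >= 1).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p


def get_last_power_of_2_num_tokens_buckets(max_num_tokens: int):
    # Descending powers of two from max_num_tokens down to 1.
    buckets = []
    m = last_positive_power_of_2(max_num_tokens)
    while m >= 1:
        buckets.append(m)
        m //= 2
    return tuple(buckets)


# With tune_max_num_tokens=8192, buckets are generated from 8192 rather than
# from the warmup batch, and a runtime token count is rounded down to a
# bucket and clamped before the cache lookup:
print(get_last_power_of_2_num_tokens_buckets(8192)[:4])  # (8192, 4096, 2048, 1024)
print(min(last_positive_power_of_2(3000), 8192))         # 2048
```

During warmup the tuner profiles each bucket once; at inference `_find_nearest_profile` maps the live token count into a bucket and clamps it to `tune_max_num_tokens`, so a lookup lands on one of the profiled entries.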
@@ -186,6 +179,8 @@ def fused_moe(
         use_fused_finalize=use_fused_finalize,
     )

+    MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
diff --git a/tests/unittest/_torch/test_autotuner.py b/tests/unittest/_torch/test_autotuner.py
index 21eb0a96260..c2f5c32141a 100644
--- a/tests/unittest/_torch/test_autotuner.py
+++ b/tests/unittest/_torch/test_autotuner.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List

 import torch

@@ -19,14 +19,16 @@ def test_multi_dynamic_dims():
     x = torch.rand([5, 1024])
     w = torch.rand([7, 19])
     dynamic_tensor_specs = (
-        DynamicTensorSpec(0, 0, [1, 3, 5], lambda x: x // 2),
-        DynamicTensorSpec(0, 1, [16, 24, 1024], lambda x: x // 2),
+        DynamicTensorSpec(0, 0, [1, 3, 5]),
+        DynamicTensorSpec(0, 1, [16, 24, 1024]),
         DynamicTensorSpec(1, 1, [3, 7, 9], lambda x: x // 2),
     )
     profiles = tuner._optimization_profiles(
         tuning_config=TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs),
         inputs=[x, w])

+    # choices for dim (0, 0) * choices for dim (0, 1) * choices for dim (1, 1)
+    # = 3 * 3 * 3 = 27: each current value maps into an existing bucket (e.g. 19 -> 9), so no extra bucket is added
     assert len(profiles) == 27
     sample_0 = OptimizationProfile(shapes=[[
         DynamicDim(min=1, opt=1, max=3),
@@ -90,7 +92,7 @@ def check_gemm_tactic_valid(tactic: int, m: int) -> bool:
 class GemmRunner(TunableRunner):

     def get_valid_tactics(self, inputs: List[FakeTensor],
-                          profile: OptimizationProfile) -> List[int]:
+                          profile: OptimizationProfile, **kwargs) -> List[int]:
         # The simulated delay is not deterministic, so we need to return specific tactics here
         return [-1, 0, 1]

@@ -98,7 +100,8 @@ def forward(self,
                 /,
                 inputs: List[torch.Tensor],
                 *,
-                tactic: int = -1) -> torch.Tensor:
+                tactic: int = -1,
+                **kwargs) -> torch.Tensor:
         assert tactic in [-1, 0, 1]
         return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs)

@@ -258,14 +261,18 @@ def test_multiple_runners_different_attributes():

     # Verify different cache keys are generated
     shapes = (x.shape, w.shape)
-    cache_key_0 = tuner._get_cache_key(custom_op="test_multiple_runners",
-                                       input_shapes=shapes,
-                                       runner=runner_0,
-                                       tuning_config=tuning_config)
-    cache_key_1 = tuner._get_cache_key(custom_op="test_multiple_runners",
-                                       input_shapes=shapes,
-                                       runner=runner_1,
-                                       tuning_config=tuning_config)
+    cache_key_0 = tuner._get_cache_key(
+        custom_op="test_multiple_runners",
+        input_shapes=shapes,
+        runner=runner_0,
+        tuning_config=tuning_config,
+    )
+    cache_key_1 = tuner._get_cache_key(
+        custom_op="test_multiple_runners",
+        input_shapes=shapes,
+        runner=runner_1,
+        tuning_config=tuning_config,
+    )

     assert cache_key_0 != cache_key_1, "Runners with different attributes should have different cache keys"

@@ -301,3 +308,47 @@ def test_multiple_dynamic_shapes_cache():
     ]
     assert len(cache_entries) == 12, \
         f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"
+
+
+class GemmRunnerWithTacticConfigs(TunableRunner):
+    valid_tactic_ids = [-1, 0, 1]
+
+    def get_valid_tactics(
+        self,
+        inputs: List[FakeTensor],
+        profile: OptimizationProfile,
+    ) -> List[Dict[str, int]]:
+        # The simulated delay is not deterministic, so we need to return specific tactics here
+        return [{
+            "block_size": block_size,
+            "tactic_id": tactic_id
+        } for tactic_id in self.valid_tactic_ids for block_size in [128, 256]]
+
+    def forward(
+        self,
+        /,
+        inputs: List[torch.Tensor],
+        *,
+        tactic: Any = -1,
+    ) -> torch.Tensor:
+        # Notice that in the fallback case the tactic is -1
+        if tactic == -1:
+            # Assign default configs for the fallback case
+            block_size, tactic_id = 128, -1
+        else:
+            block_size, tactic_id = tactic["block_size"], tactic["tactic_id"]
+        assert tactic_id in self.valid_tactic_ids
+        return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs)
+
+
+def test_autotuner_tactic_configs():
+    runner_0 = GemmRunnerWithTacticConfigs()
+    runners = [runner_0]
+    x, w = torch.randn(64, 64), torch.randn(64, 128)
+    tuning_config = TuningConfig()
+    with autotune():
+        tuner = AutoTuner.get()
+        runner, tactic = tuner.choose_one("test_autotuner_tactic_configs",
+                                          runners, tuning_config, [x, w])
+
+    runner_0.forward(inputs=[x, w], tactic=tactic)
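A minimal end-to-end sketch of the relaxed tactic contract, for reference. Everything named here (`ScaledMatmul`, its `tile` config, the `example::scaled_matmul` op string) is hypothetical, and the imports assume these names are exported by `tensorrt_llm._torch.autotuner` as the tests above use them:

```python
from typing import Any, Dict, List

import torch

# Assumed import path, matching the module touched by this diff.
from tensorrt_llm._torch.autotuner import (AutoTuner, DynamicTensorSpec,
                                           TunableRunner, TuningConfig,
                                           autotune)


class ScaledMatmul(TunableRunner):
    """Hypothetical runner whose tactics are dicts of named configs."""

    def get_valid_tactics(self, inputs: List[torch.Tensor], profile,
                          **kwargs) -> List[Dict[str, int]]:
        # Tactics are List[Any] under the new contract; each tactic here
        # carries a named `tile` config instead of a bare integer ID.
        return [{"tile": tile} for tile in (64, 128)]

    def forward(self, /, inputs: List[torch.Tensor], *, tactic: Any = -1,
                **kwargs) -> torch.Tensor:
        # tactic == -1 is the mandatory fallback used when tuning is off.
        tile = 64 if tactic == -1 else tactic["tile"]
        a, b = inputs
        return a @ b  # a real runner would dispatch on `tile`


config = TuningConfig(
    dynamic_tensor_specs=(DynamicTensorSpec(0, 0, (1, 8, 64)),),
    tune_max_num_tokens=64,  # cap the token buckets used while tuning
)
a, b = torch.randn(16, 32), torch.randn(32, 8)
with autotune():
    runner, tactic = AutoTuner.get().choose_one("example::scaled_matmul",
                                                [ScaledMatmul()], config,
                                                [a, b])
# The chosen tactic is a dict such as {"tile": 64}; replay it directly.
out = runner([a, b], tactic=tactic)
```

Since `tactic` can now be any value, replaying the tuner's choice through `forward` needs no integer bookkeeping on the caller's side.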