diff --git a/examples/layer_wise_benchmarks/run.py b/examples/layer_wise_benchmarks/run.py index 8e590dec44e..c1e3ab51339 100644 --- a/examples/layer_wise_benchmarks/run.py +++ b/examples/layer_wise_benchmarks/run.py @@ -206,7 +206,7 @@ def comma_separated_floats(s): if autotune_flag: if args.enable_autotuner: cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None - with autotune(cache_path=cache_path, rank=rank): + with autotune(cache_path=cache_path): run_pack() if args.run_type == "GEN": logger.info("Layer-wise benchmarks: Prefill KV cache") diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 748b6cbe043..679ce2ad822 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -1,6 +1,7 @@ import ast import contextlib import copy +import enum import inspect import itertools import json @@ -16,8 +17,25 @@ from cuda.bindings import driver import tensorrt_llm +from tensorrt_llm._torch.distributed import Distributed from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + + +class DistributedTuningStrategy(enum.Enum): + """ + Strategy for distributed tuning. + Args: + BROADCAST: One rank (rank 0) tunes and broadcasts results to others + INDEPENDENT: Each rank tunes independently (default for non-comm ops) + MERGE: All ranks participate in tuning and reach merge + PARALLEL: All ranks participate in tuning with partial tactics + """ + BROADCAST = "broadcast" + INDEPENDENT = "independent" + MERGE = "merge" + PARALLEL = "parallel" @dataclass(slots=True, unsafe_hash=True) @@ -99,6 +117,7 @@ class TuningConfig: This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache. Notice that not all tuning processes can benefit from this feature. use_cuda_graph (bool): Whether to use CUDA graph for the tuning process. + distributed_tuning_strategy (DistributedTuningStrategy): Strategy for distributed tuning. """ dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = () constraint_specs: Tuple[ConstraintSpec, ...] = () @@ -106,6 +125,7 @@ class TuningConfig: inputs_pre_hook: Callable = None use_cold_l2_cache: bool = False use_cuda_graph: bool = True + distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.INDEPENDENT @dataclass(unsafe_hash=True) @@ -229,7 +249,16 @@ def unique_id(self): @contextlib.contextmanager -def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0): +def autotune(tune_mode: bool = True, cache_path: str = None): + """Context manager for autotuning with distributed support. + + Args: + tune_mode: Whether to enable tuning mode + cache_path: Path to save/load cache files + """ + autotuner = AutoTuner.get() + rank = autotuner.mapping.rank + # if cache_path is provided, use the rank-specific file tune_required = tune_mode if cache_path is not None: @@ -242,25 +271,27 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0): if file_exists: logger.info( f"[Autotuner] Loading cache from {cache_path_no_ext_rank}") - AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank) + autotuner.profiling_cache.load_cache(cache_path_no_ext_rank) # record the old tuning mode - old_mode = AutoTuner.get().is_tuning_mode - AutoTuner.get().is_tuning_mode = tune_required + old_mode = autotuner.is_tuning_mode + autotuner.is_tuning_mode = tune_required autotune_enabled = tune_required and not old_mode + if autotune_enabled: logger.info("[Autotuner] Autotuning process starts ...") + try: yield finally: - AutoTuner.get().is_tuning_mode = old_mode + autotuner.is_tuning_mode = old_mode if autotune_enabled: logger.info("[Autotuner] Autotuning process ends") # save cache if cache_path is not None: logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}") - AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank) + autotuner.profiling_cache.save_cache(cache_path_no_ext_rank) @dataclass @@ -399,6 +430,9 @@ def get_cache_key( ), ) + def merge_cache_data(self, cache_data: Dict[str, Any]): + self.cache.update(cache_data) + def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]: return {k: v for k, v in self.cache.items() if k[0] == custom_op} @@ -561,6 +595,11 @@ class AutoTuner: _instance = None def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000): + # Increase log level for AutoTuner associated logger` + self._log_level_to_info = os.getenv( + "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1' + self._debug_logger = logger.info if self._log_level_to_info else logger.debug + self.repeat = repeat self.warmup = warmup self.stream_delay_micro_secs = stream_delay_micro_secs @@ -575,10 +614,9 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000): # Last captured choose_one() contexts self._last_capture: Optional['AutoTuner.TacticsCapture'] = None - # Increase log level for AutoTuner associated logger - self._log_level_to_info = os.getenv( - "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1' - self._debug_logger = logger.info if self._log_level_to_info else logger.debug + # Dsitributed tuning state + self._dist: Optional[Distributed] = None + self.mapping: Mapping = Mapping() @classmethod def get(cls): @@ -586,6 +624,9 @@ def get(cls): cls._instance = AutoTuner() return cls._instance + def set_mapping(self, mapping: Mapping = None): + self.mapping = mapping + class TacticsCapture: """Object returned by capture() that can be iterated to get all tactic combinations. @@ -768,42 +809,26 @@ def choose_one( self.stats.tuned_op_profiled_configs[custom_op] = 0 if custom_op not in self.stats.failed_profiling_count: self.stats.failed_profiling_count[custom_op] = set() - new_tuning_failure_occured = False - - for p in profiles: - tensors = self._prepare_input_tensors(p, inputs) - is_cache_hit, *_ = self.profiling_cache.search_cache( - custom_op, runners, p.get_opt_shapes(), tuning_config) - if not is_cache_hit: - # Initialize runner and tactic as None in case of no valid tactic or runners are found - best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners( - custom_op, runners, tensors, p, tuning_config, **kwargs) - if best_runner_id is not None: - # At least one valid (runner, tactic) pair is found - cache_key = self.profiling_cache.get_cache_key( - custom_op, runners[best_runner_id], p.get_opt_shapes(), - tuning_config) - - self._debug_logger( - f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}." - ) - # inspect call stack - self.profiling_cache[cache_key] = (best_runner_id, - best_tactic, min_time) - - self.stats.tuned_op_profiled_configs[custom_op] += 1 - else: - logger.warning_once( - f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. " - f"At least one valid (runner, tactic) pair is required. " - f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op " - f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.", - key=(custom_op, "warning_autotuning_no_valid_tactic"), - ) - new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured + new_tuning_failure_occurred = False + + # Synchronize ranks before profiling + if self._should_current_rank_tune( + tuning_config.distributed_tuning_strategy): + for p in profiles: + tensors = self._prepare_input_tensors(p, inputs) + is_cache_hit, *_ = self.profiling_cache.search_cache( + custom_op, runners, p.get_opt_shapes(), tuning_config) + if not is_cache_hit: + # Initialize runner and tactic as None in case of no valid tactic or runners are found + best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners( + custom_op, runners, tensors, p, tuning_config, **kwargs) + new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred + + self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy, + custom_op) # If failed profiling tactics occurs, log the error. - if new_tuning_failure_occured: + if new_tuning_failure_occurred: logger.warning_once( f"[Autotuner] New tuning error occurs:" f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. " @@ -834,7 +859,7 @@ def _profile_runners( **kwargs, ) -> float: min_time = float('inf') - has_tuning_failure_occured = False + has_tuning_failure_occurred = False best_runner_id, best_tactic = None, None # If the inputs_pre_hook is provided, it will be called before profiling. if tuning_config.inputs_pre_hook is not None: @@ -845,8 +870,11 @@ def _profile_runners( p.name for p in inspect.signature(runner.forward).parameters.values() } - valid_tactics = runner.get_valid_tactics(input_tensors, profile, - **kwargs) + all_valid_tactics = runner.get_valid_tactics( + input_tensors, profile, **kwargs) + + valid_tactics = self._maybe_parallelize_tactics( + all_valid_tactics, tuning_config.distributed_tuning_strategy) if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: runner( input_tensors, @@ -882,12 +910,36 @@ def _profile_runners( # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics # or some runtime error occurs during profiling. time_measured = float('inf') - has_tuning_failure_occured = True + has_tuning_failure_occurred = True if time_measured < min_time: min_time = time_measured best_runner_id, best_tactic = runner_id, tac - return best_runner_id, best_tactic, min_time, has_tuning_failure_occured + if best_runner_id is not None: + # At least one valid (runner, tactic) pair is found + cache_key = self.profiling_cache.get_cache_key( + custom_op, runners[best_runner_id], profile.get_opt_shapes(), + tuning_config) + + self._debug_logger( + f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}." + ) + # inspect call stack + # TODO: use named tuple to make it more readable + self.profiling_cache[cache_key] = (best_runner_id, best_tactic, + min_time) + + self.stats.tuned_op_profiled_configs[custom_op] += 1 + else: + logger.warning_once( + f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={profile.get_opt_shapes()}. " + f"At least one valid (runner, tactic) pair is required. " + f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op " + f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.", + key=(custom_op, "warning_autotuning_no_valid_tactic"), + ) + + return best_runner_id, best_tactic, min_time, has_tuning_failure_occurred def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: @@ -1358,3 +1410,103 @@ def _cudaGetErrorEnum(self, error) -> str: return nvrtc.nvrtcGetErrorString(error)[1] else: raise RuntimeError("Unknown error type: {}".format(error)) + + def setup_distributed_state(self, mapping: Mapping, dist: Distributed): + """Setup distributed communication state for autotuning.""" + self.mapping = mapping + self._dist = dist + self._debug_logger( + f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}" + ) + + def _is_distributed(self) -> bool: + """Check if we're in a distributed environment.""" + return self.mapping is not None and self.mapping.tp_size > 1 and self._dist is not None + + def _maybe_parallelize_tactics( + self, all_valid_tactics: List[Any], + strategy: DistributedTuningStrategy) -> List[Any]: + """Parallelize tactics across all TP ranks if strategy is PARALLEL.""" + if strategy == DistributedTuningStrategy.PARALLEL: + # only distribute across TP ranks + # each TP rank will only tune the tactics that are assigned to it + tp_size = self.mapping.tp_size + tp_rank = self.mapping.tp_rank + valid_tactics = [] + for idx, tactic in enumerate(all_valid_tactics): + if idx % tp_size == tp_rank: + valid_tactics.append(tactic) + return valid_tactics + else: + return all_valid_tactics + + def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy, + custom_op: str): + """Synchronize cache data across all ranks.""" + if not self._is_distributed(): + logger.warning( + f"[AutoTuner] Not in distributed environment, skipping synchronization" + ) + return + + if strategy == DistributedTuningStrategy.BROADCAST: + self._broadcast_cache_data(custom_op) + elif strategy == DistributedTuningStrategy.INDEPENDENT: + return + elif strategy == DistributedTuningStrategy.MERGE: + self._merge_cache_data(custom_op) + elif strategy == DistributedTuningStrategy.PARALLEL: + self._merge_cache_data(custom_op) + else: + logger.error( + f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent" + ) + return + + def _merge_cache_data(self, custom_op: str): + cache_data = self.profiling_cache.get_specific_custom_op(custom_op) + merged_cache_data = dict() + all_cache_data = self._dist.tp_allgather(obj=cache_data) + + for data in all_cache_data: + for key, value in data.items(): + current_time = merged_cache_data.get(key, [ + float('inf'), + ])[-1] + if value[-1] < current_time: + merged_cache_data[key] = value + + self.profiling_cache.merge_cache_data(merged_cache_data) + + def _broadcast_cache_data( + self, + custom_op: str, + ) -> None: + """Broadcast tactics from root rank to all other ranks.""" + cache_data = self.profiling_cache.get_specific_custom_op(custom_op) + root = 0 + cache_data = self._dist.tp_broadcast(obj=cache_data, root=root) + + self.profiling_cache.merge_cache_data(cache_data) + + def _should_current_rank_tune(self, + strategy: DistributedTuningStrategy) -> bool: + """Determine if this rank should perform tuning based on strategy.""" + if not self._is_distributed(): + return True + + if strategy == DistributedTuningStrategy.BROADCAST: + # Only rank 0 tunes + return self.mapping.rank == 0 + elif strategy in { + DistributedTuningStrategy.INDEPENDENT, + DistributedTuningStrategy.MERGE, + DistributedTuningStrategy.PARALLEL, + }: + # All ranks tune independently + return True + else: + logger.error( + f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent" + ) + return True diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index d497ace49b2..1b072eba481 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -7,8 +7,9 @@ from ..._utils import get_sm_version from ...math_utils import pad_up -from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, - OptimizationProfile, TunableRunner, TuningConfig) +from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy, + DynamicTensorSpec, OptimizationProfile, TunableRunner, + TuningConfig) from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE from ..utils import (fp4_scale_infer_shape, get_last_power_of_2_num_tokens_buckets, @@ -364,6 +365,7 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner): last_positive_power_of_2), ), constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ), use_cold_l2_cache=True, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL, ) def __init__(self, diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index fe09758cfe5..d338f611455 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -10,8 +10,9 @@ from tensorrt_llm._utils import get_sm_version from tensorrt_llm.logger import logger -from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, - OptimizationProfile, TunableRunner, TuningConfig) +from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy, + DynamicTensorSpec, OptimizationProfile, TunableRunner, + TuningConfig) from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE from ..modules.multi_stream_utils import do_multi_stream @@ -35,6 +36,7 @@ class MoERunner(TunableRunner): 0, 0, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2), ), tune_max_num_tokens=8192, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL, ) def __init__( @@ -103,11 +105,8 @@ def unique_id(self): self.output_dtype, self.top_k, self.tp_size, - self.tp_rank, self.ep_size, - self.ep_rank, self.cluster_size, - self.cluster_rank, self.enable_alltoall, self.use_deepseek_fp8_block_scale, self.use_w4_group_scaling, diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index a8236d88fcf..f3918d0aa2c 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -11,8 +11,9 @@ last_positive_power_of_2, next_positive_power_of_2) -from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, - OptimizationProfile, TunableRunner, TuningConfig) +from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy, + DynamicTensorSpec, OptimizationProfile, TunableRunner, + TuningConfig) def prepare_dummy_topk_and_hook( @@ -345,8 +346,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config @@ -667,8 +670,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config @@ -966,8 +971,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config @@ -1237,8 +1244,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config @@ -1506,8 +1515,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config @@ -1764,8 +1775,10 @@ def get_tuning_config(cls) -> TuningConfig: dynamic_tensor_specs = cls.get_dynamic_tensor_specs() constraint_specs = cls.get_constraint_specs() - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) + tuning_config = TuningConfig( + dynamic_tensor_specs=dynamic_tensor_specs, + constraint_specs=constraint_specs, + distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL) return tuning_config diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 5da64a55695..3deb54788d4 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -625,7 +625,7 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager): """Runs a forward pass to populate the autotuner cache.""" if not self.llm_args.enable_autotuner: return - + AutoTuner.get().setup_distributed_state(self.mapping, self.dist) logger.info("Running autotuner warmup...") kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) @@ -635,8 +635,7 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager): self.batch_size * (self.max_seq_len - 1)) cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None) - with self.no_cuda_graph(), autotune(cache_path=cache_path, - rank=self.mapping.rank): + with self.no_cuda_graph(), autotune(cache_path=cache_path): warmup_request = self._create_warmup_request( resource_manager, curr_max_num_tokens, 0) with self._release_batch_context(warmup_request, diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index 2323d0ac980..a6116d544f2 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -1,20 +1,38 @@ import itertools import os +import pickle +import sys import tempfile from typing import Any, List +import cloudpickle +import pytest import torch +from mpi4py import MPI +import tensorrt_llm import tensorrt_llm._torch.autotuner as autotuner -from tensorrt_llm._torch.autotuner import (AutoTuner, DynamicDim, - DynamicTensorSpec, FakeTensor, - OptimizationProfile, StaticDim, - TunableRunner, TuningConfig, - autotune) +from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy, + DynamicDim, DynamicTensorSpec, + FakeTensor, OptimizationProfile, + StaticDim, TunableRunner, + TuningConfig, autotune) from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets, next_positive_power_of_2) from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + +# needed since we reuse the mpi executor pool, first test running will leak a thread +pytestmark = pytest.mark.threadleak(enabled=False) def test_multi_dynamic_dims(): @@ -599,3 +617,105 @@ def test_kernel_testing_mismatched_ops(): assert "Custom op mismatch" in error_msg, f"Expected 'Custom op mismatch' in error message, got: {error_msg}" assert "test_op_A" in error_msg, f"Expected 'test_op_A' in error message, got: {error_msg}" assert "test_op_B" in error_msg, f"Expected 'test_op_B' in error message, got: {error_msg}" + + +class DistributedGemmRunner(TunableRunner): + + def __init__(self, prefer_tactics: List[int] = [0, 1]): + self.prefer_tactics = prefer_tactics + + def get_valid_tactics(self, inputs, profile, **kwargs): + # Return all tactics so merge strategy can choose between them + return self.prefer_tactics + + def forward(self, inputs, *, tactic=-1, **kwargs): + # tactic 0 is slower + if tactic % 2 == 0: + for _ in range(5): + inputs[0] @ inputs[1] + return inputs[0] @ inputs[1] + + def unique_id(self): + return () + + +def _distributed_worker_function(world_size, strategy): + """Worker function to run on each MPI rank.""" + rank = tensorrt_llm.mpi_rank() + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=world_size, + pp_size=1) + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.setup_distributed_state(mapping) + + x = torch.randn(16, 32, device='cuda') + w = torch.randn(32, 64, device='cuda') + inputs = [x, w] + + if strategy == DistributedTuningStrategy.PARALLEL: + # All ranks get the same set of tactics + prefer_tactics = [0, 1, 2, 3] + else: + # Each rank prefers different tactics + prefer_tactics = [rank] + runner = DistributedGemmRunner(prefer_tactics=prefer_tactics) + config = TuningConfig(distributed_tuning_strategy=strategy) + + cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None) + with autotune(tune_mode=True, cache_path=cache_path): + tuner.choose_one(custom_op=f"test_distributed_{strategy}", + runners=[runner], + tuning_config=config, + inputs=inputs) + selected_runner, best_tactic = tuner.choose_one( + custom_op=f"test_distributed_{strategy}", + runners=[runner], + tuning_config=config, + inputs=inputs) + + if strategy == DistributedTuningStrategy.BROADCAST: + # All ranks should select tactic 0 + assert best_tactic == 0 + elif strategy == DistributedTuningStrategy.INDEPENDENT: + # Each rank should select the tactic it prefers + assert best_tactic == rank + elif strategy == DistributedTuningStrategy.MERGE: + # Because tactic 0 is slower, two ranks should always select tactic 1 + assert best_tactic == 1 + elif strategy == DistributedTuningStrategy.PARALLEL: + # Tactic 1 or 3 should be selected since they are faster. + # TODO: This might not cover the case that rank1 tunes nothing + assert best_tactic % 2 == 1 + else: + assert False, f"Unknown strategy: {strategy}" + + return True + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Requires at least 2 GPUs for this test") +@pytest.mark.parametrize( + "strategy", + [ + DistributedTuningStrategy.BROADCAST, + DistributedTuningStrategy.INDEPENDENT, + DistributedTuningStrategy.MERGE, + DistributedTuningStrategy.PARALLEL, + ], +) +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_distributed_broadcast_strategy(strategy, mpi_pool_executor): + """Test broadcast strategy with real MPI processes.""" + world_size = 2 + # Use MPIPoolExecutor to run distributed test + results = mpi_pool_executor.map( + _distributed_worker_function, + *zip(*[( + world_size, + strategy, + )] * world_size), + ) + for r in results: + assert r is True