2 changes: 1 addition & 1 deletion examples/layer_wise_benchmarks/run.py
@@ -206,7 +206,7 @@ def comma_separated_floats(s):
if autotune_flag:
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path, rank=rank):
with autotune(cache_path=cache_path):
run_pack()
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Prefill KV cache")
250 changes: 201 additions & 49 deletions tensorrt_llm/_torch/autotuner.py
@@ -1,6 +1,7 @@
import ast
import contextlib
import copy
import enum
import inspect
import itertools
import json
@@ -16,8 +17,25 @@
from cuda.bindings import driver

import tensorrt_llm
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping


class DistributedTuningStrategy(enum.Enum):
"""
Strategy for distributed tuning.
Attributes:
BROADCAST: Rank 0 tunes and broadcasts its results to the other ranks.
INDEPENDENT: Each rank tunes independently (default for non-communication ops).
MERGE: All ranks tune every tactic, then merge caches, keeping the fastest entry per key.
PARALLEL: Tactics are split round-robin across TP ranks; each rank tunes its share, then caches are merged.
"""
BROADCAST = "broadcast"
INDEPENDENT = "independent"
MERGE = "merge"
PARALLEL = "parallel"


@dataclass(slots=True, unsafe_hash=True)
@@ -99,13 +117,15 @@ class TuningConfig:
This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache.
Notice that not all tuning processes can benefit from this feature.
use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
distributed_tuning_strategy (DistributedTuningStrategy): Strategy for distributed tuning.
"""
dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
constraint_specs: Tuple[ConstraintSpec, ...] = ()
tune_max_num_tokens: int = None
inputs_pre_hook: Callable = None
use_cold_l2_cache: bool = False
use_cuda_graph: bool = True
distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.INDEPENDENT


@dataclass(unsafe_hash=True)
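For illustration, an op that wants every TP rank to tune and then share the fastest result could opt in through the new field. This is a minimal sketch, not part of the diff; all other TuningConfig fields keep their defaults:

from tensorrt_llm._torch.autotuner import DistributedTuningStrategy, TuningConfig

# Hypothetical config: merge tuned caches across ranks, fastest tactic per key wins.
merged_tuning_config = TuningConfig(
    distributed_tuning_strategy=DistributedTuningStrategy.MERGE)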
@@ -229,7 +249,16 @@ def unique_id(self):


@contextlib.contextmanager
def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
def autotune(tune_mode: bool = True, cache_path: str = None):
"""Context manager for autotuning with distributed support.

Args:
tune_mode: Whether to enable tuning mode.
cache_path: Base path for loading and saving the per-rank profiling cache file.
"""
autotuner = AutoTuner.get()
rank = autotuner.mapping.rank

# If cache_path is provided, use a rank-specific cache file
tune_required = tune_mode
if cache_path is not None:
@@ -242,25 +271,27 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
if file_exists:
logger.info(
f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank)
autotuner.profiling_cache.load_cache(cache_path_no_ext_rank)

# record the old tuning mode
old_mode = AutoTuner.get().is_tuning_mode
AutoTuner.get().is_tuning_mode = tune_required
old_mode = autotuner.is_tuning_mode
autotuner.is_tuning_mode = tune_required
autotune_enabled = tune_required and not old_mode

if autotune_enabled:
logger.info("[Autotuner] Autotuning process starts ...")

try:
yield
finally:
AutoTuner.get().is_tuning_mode = old_mode
autotuner.is_tuning_mode = old_mode
if autotune_enabled:
logger.info("[Autotuner] Autotuning process ends")

# save cache
if cache_path is not None:
logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank)
autotuner.profiling_cache.save_cache(cache_path_no_ext_rank)


@dataclass
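Callers no longer pass a rank; the context manager reads it from the tuner's mapping. A minimal usage sketch, mirroring the run.py change above (run_pack stands in for whatever workload is being tuned):

import os

from tensorrt_llm._torch.autotuner import autotune

cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path):
    run_pack()  # hypothetical workload; tuned results are saved to a per-rank cache file on exit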
@@ -399,6 +430,9 @@ def get_cache_key(
),
)

def merge_cache_data(self, cache_data: Dict[str, Any]):
self.cache.update(cache_data)

def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
return {k: v for k, v in self.cache.items() if k[0] == custom_op}

@@ -561,6 +595,11 @@ class AutoTuner:
_instance = None

def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
# Increase log level for the AutoTuner-associated logger
self._log_level_to_info = os.getenv(
"TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
self._debug_logger = logger.info if self._log_level_to_info else logger.debug

self.repeat = repeat
self.warmup = warmup
self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -575,17 +614,19 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
# Last captured choose_one() contexts
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None

# Increase log level for AutoTuner associated logger
self._log_level_to_info = os.getenv(
"TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
self._debug_logger = logger.info if self._log_level_to_info else logger.debug
# Distributed tuning state
self._dist: Optional[Distributed] = None
self.mapping: Mapping = Mapping()

@classmethod
def get(cls):
if cls._instance is None:
cls._instance = AutoTuner()
return cls._instance

def set_mapping(self, mapping: Mapping = None):
self.mapping = mapping

class TacticsCapture:
"""Object returned by capture() that can be iterated to get all tactic combinations.

@@ -768,42 +809,26 @@ def choose_one(
self.stats.tuned_op_profiled_configs[custom_op] = 0
if custom_op not in self.stats.failed_profiling_count:
self.stats.failed_profiling_count[custom_op] = set()
new_tuning_failure_occured = False

for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, *_ = self.profiling_cache.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config)
if not is_cache_hit:
# Initialize runner and tactic as None in case of no valid tactic or runners are found
best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
custom_op, runners, tensors, p, tuning_config, **kwargs)
if best_runner_id is not None:
# At least one valid (runner, tactic) pair is found
cache_key = self.profiling_cache.get_cache_key(
custom_op, runners[best_runner_id], p.get_opt_shapes(),
tuning_config)

self._debug_logger(
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
)
# inspect call stack
self.profiling_cache[cache_key] = (best_runner_id,
best_tactic, min_time)

self.stats.tuned_op_profiled_configs[custom_op] += 1
else:
logger.warning_once(
f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
f"At least one valid (runner, tactic) pair is required. "
f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
key=(custom_op, "warning_autotuning_no_valid_tactic"),
)
new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured
new_tuning_failure_occurred = False

# Only ranks selected by the distributed tuning strategy perform profiling
if self._should_current_rank_tune(
tuning_config.distributed_tuning_strategy):
for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, *_ = self.profiling_cache.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config)
if not is_cache_hit:
# best_runner_id and best_tactic come back as None if no valid tactic or runner is found
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
custom_op, runners, tensors, p, tuning_config, **kwargs)
new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred

self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
custom_op)

# If failed profiling tactics occurs, log the error.
if new_tuning_failure_occured:
if new_tuning_failure_occurred:
logger.warning_once(
f"[Autotuner] New tuning error occurs:"
f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. "
@@ -834,7 +859,7 @@ def _profile_runners(
**kwargs,
) -> float:
min_time = float('inf')
has_tuning_failure_occured = False
has_tuning_failure_occurred = False
best_runner_id, best_tactic = None, None
# If the inputs_pre_hook is provided, it will be called before profiling.
if tuning_config.inputs_pre_hook is not None:
@@ -845,8 +870,11 @@
p.name
for p in inspect.signature(runner.forward).parameters.values()
}
valid_tactics = runner.get_valid_tactics(input_tensors, profile,
**kwargs)
all_valid_tactics = runner.get_valid_tactics(
input_tensors, profile, **kwargs)

valid_tactics = self._maybe_parallelize_tactics(
all_valid_tactics, tuning_config.distributed_tuning_strategy)
if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
runner(
input_tensors,
@@ -882,12 +910,36 @@
# Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
# or some runtime error occurs during profiling.
time_measured = float('inf')
has_tuning_failure_occured = True
has_tuning_failure_occurred = True
if time_measured < min_time:
min_time = time_measured
best_runner_id, best_tactic = runner_id, tac

return best_runner_id, best_tactic, min_time, has_tuning_failure_occured
if best_runner_id is not None:
# At least one valid (runner, tactic) pair is found
cache_key = self.profiling_cache.get_cache_key(
custom_op, runners[best_runner_id], profile.get_opt_shapes(),
tuning_config)

self._debug_logger(
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
)
# inspect call stack
# TODO: use named tuple to make it more readable
self.profiling_cache[cache_key] = (best_runner_id, best_tactic,
min_time)

self.stats.tuned_op_profiled_configs[custom_op] += 1
else:
logger.warning_once(
f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={profile.get_opt_shapes()}. "
f"At least one valid (runner, tactic) pair is required. "
f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
key=(custom_op, "warning_autotuning_no_valid_tactic"),
)

return best_runner_id, best_tactic, min_time, has_tuning_failure_occurred

def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:

@@ -1358,3 +1410,103 @@ def _cudaGetErrorEnum(self, error) -> str:
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))

def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
"""Setup distributed communication state for autotuning."""
self.mapping = mapping
self._dist = dist
self._debug_logger(
f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
)

def _is_distributed(self) -> bool:
"""Check if we're in a distributed environment."""
return self.mapping is not None and self.mapping.tp_size > 1 and self._dist is not None

def _maybe_parallelize_tactics(
self, all_valid_tactics: List[Any],
strategy: DistributedTuningStrategy) -> List[Any]:
"""Parallelize tactics across all TP ranks if strategy is PARALLEL."""
if strategy == DistributedTuningStrategy.PARALLEL:
# only distribute across TP ranks
# each TP rank will only tune the tactics that are assigned to it
tp_size = self.mapping.tp_size
tp_rank = self.mapping.tp_rank
valid_tactics = []
for idx, tactic in enumerate(all_valid_tactics):
if idx % tp_size == tp_rank:
valid_tactics.append(tactic)
return valid_tactics
else:
return all_valid_tactics
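As a concrete example of the round-robin split above (standalone sketch, placeholder tactic names):

tactics = ["t0", "t1", "t2", "t3", "t4", "t5"]
tp_size, tp_rank = 4, 1
# Rank 1 profiles only indices 1 and 5; the other ranks cover the rest,
# and the per-rank results are merged afterwards.
mine = [t for i, t in enumerate(tactics) if i % tp_size == tp_rank]
assert mine == ["t1", "t5"]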

def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy,
custom_op: str):
"""Synchronize cache data across all ranks."""
if not self._is_distributed():
logger.warning(
f"[AutoTuner] Not in distributed environment, skipping synchronization"
)
return

if strategy == DistributedTuningStrategy.BROADCAST:
self._broadcast_cache_data(custom_op)
elif strategy == DistributedTuningStrategy.INDEPENDENT:
return
elif strategy == DistributedTuningStrategy.MERGE:
self._merge_cache_data(custom_op)
elif strategy == DistributedTuningStrategy.PARALLEL:
self._merge_cache_data(custom_op)
else:
logger.error(
f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
)
return

def _merge_cache_data(self, custom_op: str):
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
merged_cache_data = dict()
all_cache_data = self._dist.tp_allgather(obj=cache_data)

for data in all_cache_data:
for key, value in data.items():
current_time = merged_cache_data.get(key, [
float('inf'),
])[-1]
if value[-1] < current_time:
merged_cache_data[key] = value

self.profiling_cache.merge_cache_data(merged_cache_data)
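The merge rule in isolation, as a small sketch (keys and values are placeholders; real cache values are (runner_id, tactic, min_time) tuples):

rank0_cache = {("op", "shape_a"): (0, 3, 0.12)}
rank1_cache = {("op", "shape_a"): (1, 7, 0.09)}
merged = {}
for data in (rank0_cache, rank1_cache):
    for key, value in data.items():
        # Keep the entry with the smallest measured time for each key.
        if value[-1] < merged.get(key, (float("inf"),))[-1]:
            merged[key] = value
assert merged[("op", "shape_a")] == (1, 7, 0.09)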

def _broadcast_cache_data(
self,
custom_op: str,
) -> None:
"""Broadcast tactics from root rank to all other ranks."""
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
root = 0
cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)

self.profiling_cache.merge_cache_data(cache_data)

def _should_current_rank_tune(self,
strategy: DistributedTuningStrategy) -> bool:
"""Determine if this rank should perform tuning based on strategy."""
if not self._is_distributed():
return True

if strategy == DistributedTuningStrategy.BROADCAST:
# Only rank 0 tunes
return self.mapping.rank == 0
elif strategy in {
DistributedTuningStrategy.INDEPENDENT,
DistributedTuningStrategy.MERGE,
DistributedTuningStrategy.PARALLEL,
}:
# All ranks tune independently
return True
else:
logger.error(
f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
)
return True
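Putting the pieces together, a rough end-to-end sketch of how a distributed run might wire this up (the helper name is hypothetical; mapping and dist come from the caller's existing distributed setup):

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm.mapping import Mapping


def tune_with_distributed_support(mapping: Mapping, dist: Distributed, run_workload):
    """Hypothetical helper: wire the tuner's distributed state, then tune.

    Per strategy: BROADCAST profiles on rank 0 and tp_broadcasts the cache;
    MERGE and PARALLEL profile on every rank (all tactics or a slice of them)
    and tp_allgather the caches, keeping the fastest entry per key;
    INDEPENDENT leaves each rank with its own cache.
    """
    tuner = AutoTuner.get()
    tuner.setup_distributed_state(mapping, dist)
    with autotune():
        run_workload()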