diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 0ca4d28085b..09793491e4a 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -66,7 +66,8 @@
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
+from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
+                     create_lm_head_tp_mapping)
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
                              register_auto_model)
@@ -145,6 +146,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
+        if self.model_config.mapping.enable_attention_dp and \
+                getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
+            self.mapping_lm_head_tp = create_lm_head_tp_mapping(
+                self.model_config.mapping)
+        else:
+            self.mapping_lm_head_tp = self.model_config.mapping
 
     @torch.compile(options={"max-autotune": True})
     def get_last_token_states(self, hidden_states, attn_metadata):
@@ -167,10 +174,21 @@ def forward(self,
         else:
             hidden_states = hidden_states[-1].unsqueeze(0)
 
-        if not (self.model_config.mapping.enable_attention_dp):
+        enable_attention_dp = self.model_config.mapping.enable_attention_dp
+        enable_lm_head_tp_in_adp = self.model_config.mapping.enable_lm_head_tp_in_adp
+
+        # Pre-LM-head gather logic
+        if enable_lm_head_tp_in_adp:
+            # ADP + LM head TP mode: all-gather hidden states across the LM head TP group
+            hidden_states = allgather(hidden_states,
+                                      self.mapping_lm_head_tp,
+                                      dim=0)
+
+        # Temporarily disable gather_output whenever the LM head runs tensor-parallel: outside ADP, or inside ADP with LM head TP enabled
+        if not enable_attention_dp or enable_lm_head_tp_in_adp:
             lm_head.gather_output = False
-        logits = lm_head(hidden_states)
-        if not (self.model_config.mapping.enable_attention_dp):
+        logits = lm_head(hidden_states, is_spec_decoding_head=True)
+        if not enable_attention_dp or enable_lm_head_tp_in_adp:
             lm_head.gather_output = True
         return logits
 
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index 284a31c26a6..5b84411a2e1 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -352,7 +352,7 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
         self.pp_size = config.mapping.pp_size
         self.has_custom_lm_head = False
 
-        if config.mapping.enable_attention_dp:
+        if config.mapping.enable_attention_dp and not config.mapping.enable_lm_head_tp_in_adp:
             self.lm_head = LMHead(
                 vocab_size,
                 hidden_size,
diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py
index 3be7a652f6f..81367352114 100644
--- a/tensorrt_llm/_torch/modules/embedding.py
+++ b/tensorrt_llm/_torch/modules/embedding.py
@@ -9,6 +9,7 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import allgather
+from ..utils import create_lm_head_tp_mapping
 from .linear import Linear, TensorParallelMode
 
 
@@ -35,6 +36,11 @@ def __init__(
         local_in_features = embedding_dim
         local_out_features = num_embeddings
         mapping = mapping or Mapping()
+        self.enable_lm_head_tp_in_adp = mapping.enable_attention_dp and \
+            getattr(mapping, 'enable_lm_head_tp_in_adp', False)
+        if self.enable_lm_head_tp_in_adp:
+            mapping = create_lm_head_tp_mapping(mapping)
+
         tp_size = mapping.tp_size
 
         # Attention DP doesn't work with embedding parallelization.
@@ -72,6 +78,18 @@ def __init__(
             self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))
             self.register_parameter("bias", None)
 
+        # For LM head TP in ADP, we need to slice the weight for the LM head
+        self.lm_head_slice_obj = None
+        if self.enable_lm_head_tp_in_adp:
+            tp_rank = self.mapping.tp_rank
+            tp_size = self.mapping.tp_size
+            slice_width = math.ceil(self.out_features / tp_size)
+            slice_start = tp_rank * slice_width
+            slice_end = min((tp_rank + 1) * slice_width, self.out_features)
+            slice_obj = [slice(None)] * len(self.weight.shape)
+            slice_obj[0] = slice(slice_start, slice_end)
+            self.lm_head_slice_obj = tuple(slice_obj)
+
     @property
     def vocab_size_padded(self) -> int:
         if self.tp_mode == TensorParallelMode.COLUMN and self.gather_output:
@@ -80,12 +98,16 @@ def vocab_size_padded(self) -> int:
         return self.out_features
 
     def forward(
-            self,
-            input: torch.Tensor,
-            *,
-            all_reduce_params: Optional[AllReduceParams] = None
+        self,
+        input: torch.Tensor,
+        *,
+        all_reduce_params: Optional[AllReduceParams] = None,
+        is_spec_decoding_head: bool = False,
     ) -> torch.Tensor:
-        output = super().forward(input, all_reduce_params=all_reduce_params)
+        if is_spec_decoding_head and self.enable_lm_head_tp_in_adp:
+            output = F.linear(input, self.weight[self.lm_head_slice_obj], None)
+        else:
+            output = super().forward(input, all_reduce_params=all_reduce_params)
         if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
                 and self.padding_size > 0):
             output = output[..., :-self.padding_size]
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index ab1ef6e615d..73363798482 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -2,8 +2,11 @@
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..attention_backend import AttentionMetadata
 from ..distributed.ops import allgather
 from ..model_config import ModelConfig
@@ -1069,12 +1072,13 @@ def prepare_drafter_inputs(
         }
 
     @torch.compile(options={"max-autotune": True})
-    def get_local_max_and_combined(self, logits):
+    def get_local_max_and_combined(self, logits, mapping_lm_tp=None):
         local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
         # Adjust indices based on TP rank and size
         vocab_per_rank = logits.shape[-1]
+        mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping
         max_index_per_rank = local_argmax.type(
-            torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
+            torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank)
         # Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
         # Convert both to float32 to ensure consistent dtype
         max_index_per_rank_float = max_index_per_rank.float()
@@ -1102,6 +1106,7 @@ def get_draft_tokens_from_gathered(self, gathered):
     def draft_sampler(
         self,
         logits: torch.Tensor,
+        mapping_lm_head_tp: Mapping = None,
     ):
         '''
         Sampling draft tokens.
@@ -1123,6 +1128,20 @@ def draft_sampler(
             combined = self.get_local_max_and_combined(logits)
             gathered = allgather(combined, self.model_config.mapping, dim=-1)
             draft_tokens = self.get_draft_tokens_from_gathered(gathered)
+        elif (self.model_config is not None
+              and hasattr(self.model_config, 'mapping')
+              and self.model_config.mapping.tp_size
+              > 1) and self.model_config.mapping.enable_lm_head_tp_in_adp:
+            # ADP + LM head TP mode: find the global argmax across the LM head TP group
+            combined = self.get_local_max_and_combined(logits,
+                                                       mapping_lm_head_tp)
+            gathered = allgather(combined, mapping_lm_head_tp, dim=-1)
+            batch_size = logits.shape[0]
+            local_batch_size = batch_size // mapping_lm_head_tp.tp_size
+            gathered = gathered.view(mapping_lm_head_tp.tp_size,
+                                     local_batch_size, -1)
+            sliced_gathered = gathered[mapping_lm_head_tp.tp_rank]
+            draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
         else:
             # Simple argmax if no TP or no model config
             draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
@@ -1229,14 +1248,44 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
                 self.guided_decoder.add_draft_batch(new_tokens,
                                                     num_accepted_tokens,
                                                     draft_step=i)
-
-            logits = draft_model.mtp_layers[0].shared_head(
-                hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
-                True)
+            if self.model_config.mapping.enable_attention_dp and \
+                    getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
+                hidden_states_gathered = hidden_states[gather_ids]
+                token_count = hidden_states_gathered.view(
+                    -1, hidden_states_gathered.shape[-1]).shape[0]
+                max_num_requests = spec_metadata.max_num_requests
+                pad_len = max_num_requests - token_count
+                if pad_len > 0:
+                    padded_hidden_states = F.pad(hidden_states_gathered.view(
+                        -1, hidden_states_gathered.shape[-1]),
+                                                 (0, 0, 0, pad_len),
+                                                 mode="constant",
+                                                 value=0)
+                elif pad_len == 0:
+                    padded_hidden_states = hidden_states_gathered.view(
+                        -1, hidden_states_gathered.shape[-1])
+                else:
+                    raise ValueError(
+                        "In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
+                    )
+                logits = draft_model.mtp_layers[0].shared_head(
+                    padded_hidden_states, draft_model.lm_head, attn_metadata,
+                    True)
+            else:
+                logits = draft_model.mtp_layers[0].shared_head(
+                    hidden_states[gather_ids], draft_model.lm_head,
+                    attn_metadata, True)
 
             if self.guided_decoder is not None:
                 self.guided_decoder.execute_draft_batch(logits, draft_step=i)
-            new_draft_token = self.draft_sampler(logits)
+            if self.model_config.mapping.enable_attention_dp and \
+                    getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
+                mapping_lm_head_tp = draft_model.mtp_layers[
+                    0].shared_head.mapping_lm_head_tp
+                new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
+                new_draft_token = new_draft_token[:token_count]
+            else:
+                new_draft_token = self.draft_sampler(logits)
 
             hidden_states, position_ids = self.update_draft_tokens(
                 next_draft_tokens, new_draft_token, hidden_states, gather_ids,
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 2d19c044fd0..2123e5e0968 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -1,4 +1,5 @@
 import contextlib
+import os
 import threading
 from dataclasses import dataclass
 from enum import Enum
@@ -7,6 +8,7 @@
 import torch
 
 from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
+from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.math_utils import ceil_div, pad_up
 from tensorrt_llm.quantization.utils import fp4_utils
 
@@ -284,3 +286,19 @@ def set_per_request_piecewise_cuda_graph_flag(enable: bool):
 
 def get_per_request_piecewise_cuda_graph_flag() -> bool:
     return getattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', True)
+
+
+def create_lm_head_tp_mapping(mapping: Mapping) -> Mapping:
+    lm_head_tp_size = int(os.getenv('LM_HEAD_TP_SIZE', 2))
+    assert mapping.tp_size % lm_head_tp_size == 0
+    lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size
+
+    return Mapping(
+        world_size=lm_head_tp_size * lm_head_pp_size,
+        rank=mapping.rank,
+        gpus_per_node=mapping.gpus_per_node,
+        tp_size=lm_head_tp_size,
+        pp_size=lm_head_pp_size,
+        enable_attention_dp=mapping.enable_attention_dp,
+        enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
+    )
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 58e18ccc944..45929ff60d7 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -225,6 +225,7 @@ class _ParallelConfig:
     moe_ep_size: int = 1
     cp_config: dict = field(default_factory=dict)
     enable_attention_dp: bool = False
+    enable_lm_head_tp_in_adp: bool = False
     auto_parallel: bool = False
 
     _world_size: int = field(default=1, init=False)
@@ -288,6 +289,7 @@ def to_mapping(self) -> Mapping:
             cp_size=self.cp_size,
             cp_config=self.cp_config,
             enable_attention_dp=self.enable_attention_dp,
+            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
             moe_cluster_size=self.moe_cluster_size,
             moe_tp_size=self.moe_tp_size,
             moe_ep_size=self.moe_ep_size,
@@ -1264,6 +1266,11 @@ class BaseLlmArgs(StrictBaseModel):
         description="Enable attention data parallel.",
         status="beta")
 
+    enable_lm_head_tp_in_adp: bool = Field(
+        default=False,
+        description="Enable tensor parallelism for the LM head when attention DP is enabled.",
+        status="beta")
+
     cp_config: Optional[dict] = Field(default_factory=dict,
                                       description="Context parallel config.",
                                       status="prototype")
@@ -1511,6 +1518,7 @@ def validate_parallel_config(self):
             moe_tp_size=self.moe_tensor_parallel_size,
             moe_ep_size=self.moe_expert_parallel_size,
             enable_attention_dp=self.enable_attention_dp,
+            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
             cp_config=self.cp_config)
         return self
 
diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py
index cfc997b786a..c2c896a932c 100644
--- a/tensorrt_llm/mapping.py
+++ b/tensorrt_llm/mapping.py
@@ -141,7 +141,8 @@ def __init__(
         attn_tp_size=-1,
         attn_cp_size=-1,
         auto_parallel=False,
-        enable_attention_dp=False):
+        enable_attention_dp=False,
+        enable_lm_head_tp_in_adp=False):
         # set default values for non-moe cases
         # or where only one MOE parallelism size is specified
         if moe_cluster_size == -1:
@@ -224,6 +225,9 @@ def __init__(
         self.auto_parallel = auto_parallel
         self.world_size = world_size
         self.enable_attention_dp = enable_attention_dp
+        if enable_lm_head_tp_in_adp:
+            assert enable_attention_dp, "enable_lm_head_tp_in_adp requires enable_attention_dp"
+        self.enable_lm_head_tp_in_adp = enable_lm_head_tp_in_adp
         self.rank = rank
         self.gpus_per_node = gpus_per_node
         self.pp_groups = []
@@ -510,4 +514,6 @@ def to_dict(self):
             'attn_cp_size': self.attn_cp_size,
             'cp_config': self.cp_config,
             'auto_parallel': self.auto_parallel,
+            'enable_attention_dp': self.enable_attention_dp,
+            'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
         }
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 56573d1fcec..51518bbccdb 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -171,6 +171,10 @@ methods:
         annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConnectorConfig]
         default: null
         status: prototype
+      enable_lm_head_tp_in_adp:
+        annotation: bool
+        default: False
+        status: prototype
     return_annotation: None
   generate:
     parameters:
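
Usage sketch (not part of the patch): with the changes above applied, the new flag is expected to be driven from the LLM API roughly as follows. Only enable_attention_dp and enable_lm_head_tp_in_adp come from this patch; the checkpoint path and TP size are placeholders, and LM_HEAD_TP_SIZE is the environment variable read by create_lm_head_tp_mapping (it falls back to 2 when unset).

    # Hypothetical example; parameters other than the two ADP flags are illustrative.
    import os

    from tensorrt_llm import LLM

    # LM head TP group size consumed by create_lm_head_tp_mapping; must divide tp_size.
    os.environ.setdefault("LM_HEAD_TP_SIZE", "2")

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",   # placeholder checkpoint
        tensor_parallel_size=8,            # placeholder TP size
        enable_attention_dp=True,          # required: Mapping asserts ADP is on
        enable_lm_head_tp_in_adp=True,     # new flag introduced by this patch
    )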