Merged
26 changes: 22 additions & 4 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -66,7 +66,8 @@
from ..modules.rms_norm import RMSNorm
from ..peft.lora.layer import LoraLayer
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
create_lm_head_tp_mapping)
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
register_auto_model)
@@ -145,6 +146,12 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
self.mapping_lm_head_tp = create_lm_head_tp_mapping(
self.model_config.mapping)
else:
self.mapping_lm_head_tp = self.model_config.mapping

@torch.compile(options={"max-autotune": True})
def get_last_token_states(self, hidden_states, attn_metadata):
@@ -167,10 +174,21 @@ def forward(self,
else:
hidden_states = hidden_states[-1].unsqueeze(0)

if not (self.model_config.mapping.enable_attention_dp):
enable_attention_dp = self.model_config.mapping.enable_attention_dp
enable_lm_head_tp_in_adp = self.model_config.mapping.enable_lm_head_tp_in_adp

# Add pre-lm gather logic
if enable_lm_head_tp_in_adp:
# ADP + LM TP mode: perform All-Gather before LM_head
hidden_states = allgather(hidden_states,
self.mapping_lm_head_tp,
dim=0)

# Temporarily disable gather_output when not in ADP mode or (in ADP mode and LM TP is enabled)
if not enable_attention_dp or enable_lm_head_tp_in_adp:
lm_head.gather_output = False
logits = lm_head(hidden_states)
if not (self.model_config.mapping.enable_attention_dp):
logits = lm_head(hidden_states, is_spec_decoding_head=True)
if not enable_attention_dp or enable_lm_head_tp_in_adp:
lm_head.gather_output = True
return logits
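A rough shape sketch of the pre-LM-head gather above (toy sizes, not part of the diff): under ADP each rank holds only its own last-token hidden states, so the all-gather over the LM-head TP group stacks tokens along dim 0, and the column-parallel head with gather_output=False then emits only this rank's vocab shard.

import torch

# Toy stand-ins: 2-way LM-head TP, 3 local tokens per rank, hidden_size=16, vocab=64.
hidden_local = torch.randn(3, 16)                          # this rank's last-token hidden states
gathered = torch.cat([hidden_local, hidden_local], dim=0)  # what allgather(..., dim=0) would yield: (6, 16)
weight_shard = torch.randn(64 // 2, 16)                    # this rank's rows of lm_head.weight
local_logits = gathered @ weight_shard.t()                 # (6, 32): logits for this rank's half of the vocab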

2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
@@ -352,7 +352,7 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
self.pp_size = config.mapping.pp_size
self.has_custom_lm_head = False

if config.mapping.enable_attention_dp:
if config.mapping.enable_attention_dp and not config.mapping.enable_lm_head_tp_in_adp:
self.lm_head = LMHead(
vocab_size,
hidden_size,
32 changes: 27 additions & 5 deletions tensorrt_llm/_torch/modules/embedding.py
@@ -9,6 +9,7 @@
from tensorrt_llm.mapping import Mapping

from ..distributed import allgather
from ..utils import create_lm_head_tp_mapping
from .linear import Linear, TensorParallelMode


@@ -35,6 +36,11 @@ def __init__(
local_in_features = embedding_dim
local_out_features = num_embeddings
mapping = mapping or Mapping()
self.enable_lm_head_tp_in_adp = mapping.enable_attention_dp and \
getattr(mapping, 'enable_lm_head_tp_in_adp', False)
if self.enable_lm_head_tp_in_adp:
mapping = create_lm_head_tp_mapping(mapping)

tp_size = mapping.tp_size

# Attention DP doesn't work with embedding parallelization.
@@ -72,6 +78,18 @@ def __init__(
self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))
self.register_parameter("bias", None)

# For LM head TP in ADP, we need to slice the weight for the LM head
self.lm_head_slice_obj = None
if self.enable_lm_head_tp_in_adp:
tp_rank = self.mapping.tp_rank
tp_size = self.mapping.tp_size
slice_width = math.ceil(self.out_features / tp_size)
slice_start = tp_rank * slice_width
slice_end = min((tp_rank + 1) * slice_width, self.out_features)
slice_obj = [slice(None)] * len(self.weight.shape)
slice_obj[0] = slice(slice_start, slice_end)
self.lm_head_slice_obj = tuple(slice_obj)

@property
def vocab_size_padded(self) -> int:
if self.tp_mode == TensorParallelMode.COLUMN and self.gather_output:
@@ -80,12 +98,16 @@ def vocab_size_padded(self) -> int:
return self.out_features

def forward(
self,
input: torch.Tensor,
*,
all_reduce_params: Optional[AllReduceParams] = None
self,
input: torch.Tensor,
*,
all_reduce_params: Optional[AllReduceParams] = None,
is_spec_decoding_head: bool = False,
) -> torch.Tensor:
output = super().forward(input, all_reduce_params=all_reduce_params)
if is_spec_decoding_head and self.enable_lm_head_tp_in_adp:
output = F.linear(input, self.weight[self.lm_head_slice_obj], None)
else:
output = super().forward(input, all_reduce_params=all_reduce_params)
if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
and self.padding_size > 0):
output = output[..., :-self.padding_size]
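A hedged sketch of the is_spec_decoding_head slice path (made-up sizes): the full weight stays resident, but under ADP + LM-head TP each rank applies only its ceil(out_features / tp_size)-row slice, so the output is a per-rank vocab shard rather than the full vocabulary.

import math
import torch
import torch.nn.functional as F

out_features, hidden_size, tp_size, tp_rank = 10, 8, 2, 1   # toy values
weight = torch.randn(out_features, hidden_size)
slice_width = math.ceil(out_features / tp_size)             # 5
start = tp_rank * slice_width                                # 5
end = min((tp_rank + 1) * slice_width, out_features)         # 10
x = torch.randn(4, hidden_size)
local_logits = F.linear(x, weight[start:end], None)          # (4, 5): this rank's vocab shard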
63 changes: 56 additions & 7 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -2,8 +2,11 @@
from typing import TYPE_CHECKING, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
from ..distributed.ops import allgather
from ..model_config import ModelConfig
@@ -1069,12 +1072,13 @@ def prepare_drafter_inputs(
}

@torch.compile(options={"max-autotune": True})
def get_local_max_and_combined(self, logits):
def get_local_max_and_combined(self, logits, mapping_lm_tp=None):
local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
# Adjust indices based on TP rank and size
vocab_per_rank = logits.shape[-1]
mapping_lm_tp = mapping_lm_tp if mapping_lm_tp is not None else self.model_config.mapping
max_index_per_rank = local_argmax.type(
torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
torch.int32) + (mapping_lm_tp.tp_rank * vocab_per_rank)
# Use torch.stack and flatten instead of view+cat to avoid torch.compile issues
# Convert both to float32 to ensure consistent dtype
max_index_per_rank_float = max_index_per_rank.float()
@@ -1102,6 +1106,7 @@ def get_draft_tokens_from_gathered(self, gathered):
def draft_sampler(
self,
logits: torch.Tensor,
mapping_lm_head_tp: Mapping = None,
):
'''
Sampling draft tokens.
@@ -1123,6 +1128,20 @@
combined = self.get_local_max_and_combined(logits)
gathered = allgather(combined, self.model_config.mapping, dim=-1)
draft_tokens = self.get_draft_tokens_from_gathered(gathered)
elif (self.model_config is not None
and hasattr(self.model_config, 'mapping')
and self.model_config.mapping.tp_size
> 1) and self.model_config.mapping.enable_lm_head_tp_in_adp:
# For ADP + LM head TP mode, we need to find the global argmax across all TP ranks
combined = self.get_local_max_and_combined(logits,
mapping_lm_head_tp)
gathered = allgather(combined, mapping_lm_head_tp, dim=-1)
batch_size = logits.shape[0]
local_batch_size = batch_size // mapping_lm_head_tp.tp_size
gathered = gathered.view(mapping_lm_head_tp.tp_size,
local_batch_size, -1)
sliced_gathered = gathered[mapping_lm_head_tp.tp_rank]
draft_tokens = self.get_draft_tokens_from_gathered(sliced_gathered)
else:
# Simple argmax if no TP or no model config
draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
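A toy illustration of how the global argmax is reconstructed from vocab shards in draft_sampler: each rank reports its local max value plus a globally offset index (local_argmax + tp_rank * vocab_per_rank); whichever rank's value wins after the gather selects the token id. This mirrors get_local_max_and_combined / get_draft_tokens_from_gathered, with the collective replaced by a concat.

import torch

vocab_per_rank = 4
shards = [torch.tensor([[0.1, 0.9, 0.2, 0.3]]),   # rank 0's logits shard
          torch.tensor([[0.4, 0.2, 1.5, 0.0]])]   # rank 1's logits shard

values, indices = [], []
for tp_rank, local in enumerate(shards):
    val, idx = torch.max(local, dim=-1, keepdim=True)
    values.append(val)
    indices.append(idx + tp_rank * vocab_per_rank)   # offset to a global vocab index

values = torch.cat(values, dim=-1)                   # stand-in for the allgather
indices = torch.cat(indices, dim=-1)
winner = torch.argmax(values, dim=-1, keepdim=True)
draft_token = torch.gather(indices, -1, winner)      # tensor([[6]]) == argmax over the full vocab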
@@ -1229,14 +1248,44 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
True)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError(
f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)

new_draft_token = self.draft_sampler(logits)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)

hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
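A minimal sketch (assumed shapes) of the pad-then-trim logic above: local token counts can differ across ranks, so hidden states are padded up to max_num_requests before the shared head so the all-gather sees identical shapes on every rank, and the sampled draft tokens are trimmed back to the real count afterwards.

import torch
import torch.nn.functional as F

hidden = torch.randn(3, 16)                   # token_count = 3 local tokens (toy sizes)
max_num_requests = 4
pad_len = max_num_requests - hidden.shape[0]
padded = F.pad(hidden, (0, 0, 0, pad_len))    # (4, 16): zero rows appended at the end
# ... shared_head / lm_head / draft_sampler would run on `padded` ...
sampled = torch.arange(max_num_requests)      # stand-in for the sampled draft tokens
sampled = sampled[:hidden.shape[0]]           # drop tokens that came from padding rows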
18 changes: 18 additions & 0 deletions tensorrt_llm/_torch/utils.py
@@ -1,4 +1,5 @@
import contextlib
import os
import threading
from dataclasses import dataclass
from enum import Enum
@@ -7,6 +8,7 @@
import torch

from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils

@@ -284,3 +286,19 @@ def set_per_request_piecewise_cuda_graph_flag(enable: bool):

def get_per_request_piecewise_cuda_graph_flag() -> bool:
return getattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', True)


def create_lm_head_tp_mapping(mapping: Mapping) -> Mapping:
lm_head_tp_size = int(os.getenv('LM_HEAD_TP_SIZE', 2))
assert mapping.tp_size % lm_head_tp_size == 0
lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size

return Mapping(
world_size=lm_head_tp_size * lm_head_pp_size,
rank=mapping.rank,
gpus_per_node=mapping.gpus_per_node,
tp_size=lm_head_tp_size,
pp_size=lm_head_pp_size,
enable_attention_dp=mapping.enable_attention_dp,
enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
)
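A hypothetical invocation of the helper (placeholder sizes): with 8-way TP and the default LM_HEAD_TP_SIZE=2, the LM head keeps 2-way TP and the remaining factor is folded into pp_size, so world_size is unchanged.

import os
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.utils import create_lm_head_tp_mapping

os.environ['LM_HEAD_TP_SIZE'] = '2'
full = Mapping(world_size=8, rank=3, gpus_per_node=8, tp_size=8, pp_size=1,
               enable_attention_dp=True, enable_lm_head_tp_in_adp=True)
lm_head_mapping = create_lm_head_tp_mapping(full)
# lm_head_mapping.tp_size == 2, lm_head_mapping.pp_size == 1 * 8 // 2 == 4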
8 changes: 8 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -225,6 +225,7 @@ class _ParallelConfig:
moe_ep_size: int = 1
cp_config: dict = field(default_factory=dict)
enable_attention_dp: bool = False
enable_lm_head_tp_in_adp: bool = False
auto_parallel: bool = False

_world_size: int = field(default=1, init=False)
@@ -288,6 +289,7 @@ def to_mapping(self) -> Mapping:
cp_size=self.cp_size,
cp_config=self.cp_config,
enable_attention_dp=self.enable_attention_dp,
enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
moe_cluster_size=self.moe_cluster_size,
moe_tp_size=self.moe_tp_size,
moe_ep_size=self.moe_ep_size,
@@ -1264,6 +1266,11 @@ class BaseLlmArgs(StrictBaseModel):
description="Enable attention data parallel.",
status="beta")

enable_lm_head_tp_in_adp: bool = Field(
default=False,
description="Enable LM head TP in attention dp.",
status="beta")
Review comment (Collaborator): Status should be "prototype" rather than "beta"? @Njuapp
Reply (Collaborator): Yes, I submitted a new PR for this line of code: #7891

cp_config: Optional[dict] = Field(default_factory=dict,
description="Context parallel config.",
status="prototype")
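A hedged end-to-end usage sketch of the new flag through the LLM API (the model name and parallel sizes are placeholders): it only takes effect together with enable_attention_dp.

from tensorrt_llm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3",   # placeholder checkpoint
          tensor_parallel_size=8,
          enable_attention_dp=True,
          enable_lm_head_tp_in_adp=True)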
@@ -1511,6 +1518,7 @@ def validate_parallel_config(self):
moe_tp_size=self.moe_tensor_parallel_size,
moe_ep_size=self.moe_expert_parallel_size,
enable_attention_dp=self.enable_attention_dp,
enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
cp_config=self.cp_config)
return self

8 changes: 7 additions & 1 deletion tensorrt_llm/mapping.py
@@ -141,7 +141,8 @@ def __init__(
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False):
enable_attention_dp=False,
enable_lm_head_tp_in_adp=False):
# set default values for non-moe cases
# or where only one MOE parallelism size is specified
if moe_cluster_size == -1:
@@ -224,6 +225,9 @@ def __init__(
self.auto_parallel = auto_parallel
self.world_size = world_size
self.enable_attention_dp = enable_attention_dp
if enable_lm_head_tp_in_adp:
assert enable_attention_dp, "enable_lm_head_tp_in_adp requires enable_attention_dp"
self.enable_lm_head_tp_in_adp = enable_lm_head_tp_in_adp
self.rank = rank
self.gpus_per_node = gpus_per_node
self.pp_groups = []
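A quick illustration (sketch, assuming default arguments otherwise) of the invariant added above: requesting LM-head TP in ADP without attention DP trips the new assertion.

from tensorrt_llm.mapping import Mapping

try:
    Mapping(world_size=4, tp_size=4, enable_lm_head_tp_in_adp=True)
except AssertionError as err:
    print(err)   # enable_lm_head_tp_in_adp requires enable_attention_dp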
@@ -510,4 +514,6 @@ def to_dict(self):
'attn_cp_size': self.attn_cp_size,
'cp_config': self.cp_config,
'auto_parallel': self.auto_parallel,
'enable_attention_dp': self.enable_attention_dp,
'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
}
4 changes: 4 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -171,6 +171,10 @@ methods:
annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConnectorConfig]
default: null
status: prototype
enable_lm_head_tp_in_adp:
annotation: bool
default: False
status: prototype
return_annotation: None
generate:
parameters: