Merged
Address comments to remove code setting strategies to Linear when no mapping is given

Signed-off-by: Hui Gao <huig@nvidia.com>
HuiGao-NV committed Jun 11, 2025
commit a7fab8b1ce9e8bc7802ca6e051961e8c76c06b3b
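Note: the diff below renames ModelConfig.allreduce_backend to allreduce_strategy and threads the value through Linear and AllReduce. A minimal sketch of the resulting call pattern (not taken from the diff; the sizes, dtype, and single-rank Mapping are placeholders, and a TensorRT-LLM install is assumed):

```python
import torch

from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, tp_size=1, rank=0)  # single rank, for illustration only
strategy = AllReduceStrategy.AUTO                   # the default used throughout this commit

# Row-parallel layers forward the strategy to their internal AllReduce.
down_proj = Linear(4096,
                   1024,
                   bias=False,
                   dtype=torch.bfloat16,
                   mapping=mapping,
                   tensor_parallel_mode=TensorParallelMode.ROW,
                   allreduce_strategy=strategy)
```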
34 changes: 16 additions & 18 deletions examples/pytorch/out_of_tree_example/modeling_opt.py
@@ -64,24 +64,22 @@ def __init__(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
dtype=config.torch_dtype)
self.fc1 = Linear(
config.hidden_size,
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
)
self.fc2 = Linear(
config.ffn_dim,
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
)
self.fc1 = Linear(config.hidden_size,
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.fc2 = Linear(config.ffn_dim,
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.final_layer_norm = LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -6,7 +6,7 @@
try:
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams
from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy

def trtllm_allgather(tensor, dim, sizes=None):
rank, world_size = get_rank_world_size()
@@ -17,7 +17,7 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(mapping=p_config)
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
return torch_op(tensor, all_reduce_params=all_reduce_params)

@torch.library.custom_op(
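Note: the wrapper now pins the strategy explicitly. A hedged, self-contained sketch of the same helper, with rank and world_size passed in instead of coming from the module's get_rank_world_size(); it is meant to run with one process per GPU (e.g. under mpirun):

```python
import torch
from torch.distributed import ReduceOp

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed import AllReduce
from tensorrt_llm.functional import AllReduceStrategy


def trtllm_allreduce(tensor: torch.Tensor, op, rank: int, world_size: int,
                     all_reduce_params=None) -> torch.Tensor:
    assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
    # The strategy is now set explicitly; AUTO lets the runtime choose between
    # NCCL and the custom fused kernels.
    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
    return torch_op(tensor, all_reduce_params=all_reduce_params)
```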
8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/model_config.py
@@ -79,7 +79,7 @@ class ModelConfig(Generic[TConfig]):

attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM
allreduce_backend: AllReduceStrategy = AllReduceStrategy.AUTO
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

# If true, enable min-latency mode. Currently only used for Llama4.
enable_min_latency: bool = False
@@ -123,9 +123,9 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
key = strategy.upper()
return maps[key] if key in maps else AllReduceStrategy.AUTO

if isinstance(self.allreduce_backend, str):
self.allreduce_backend = get_all_reduce_strategy(
self.allreduce_backend)
if isinstance(self.allreduce_strategy, str):
self.allreduce_strategy = get_all_reduce_strategy(
self.allreduce_strategy)

@property
def fuse_pos_embd(self):
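Note: a small usage sketch of the renamed field, assuming ModelConfig's remaining fields keep their defaults and that the usual strategy names ("AUTO", "NCCL", ...) are keys in the lookup table used by get_all_reduce_strategy:

```python
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.functional import AllReduceStrategy

# Strings are normalized to the enum in __post_init__.
cfg = ModelConfig(allreduce_strategy="NCCL")
assert cfg.allreduce_strategy == AllReduceStrategy.NCCL

# Unknown names fall back to AUTO (see the hunk above).
cfg = ModelConfig(allreduce_strategy="NOT_A_STRATEGY")
assert cfg.allreduce_strategy == AllReduceStrategy.AUTO
```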
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -400,7 +400,7 @@ def __init__(self,
reduce_output=False)

self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
@@ -630,7 +630,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
dtype=config.torch_dtype)
self.layer_idx = layer_idx
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_backend,
strategy=model_config.allreduce_strategy,
dtype=config.torch_dtype)
self.moe_allreduce = MoEAllReduce(self.mapping)
self.next_layer_layernorm: RMSNorm = None
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -284,7 +284,7 @@ def __init__(
self.mapping = model_config.mapping
self.all_reduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_backend,
strategy=model_config.allreduce_strategy,
)
self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
self.aux_stream = aux_stream
@@ -418,7 +418,7 @@ def __init__(

self.mapping = model_config.mapping
self.all_reduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)
self.next_layer_layernorm: RMSNorm = None
self.next_attn: LlamaAttention = None

@@ -629,7 +629,7 @@ def __init__(
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
allreduce_strategy=model_config.allreduce_strategy)


class Eagle3LlamaDecoderLayer(DecoderLayer):
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -44,7 +44,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
gather_output=True,
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.skip_create_weights_in_init,
)
allreduce_strategy=model_config.allreduce_strategy)


class NemotronNASAttention(Attention):
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -90,7 +90,7 @@ def __init__(
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mapping = model_config.mapping
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)
self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
model_config, self.top_k)
if self.enable_alltoall:
@@ -204,7 +204,7 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
self.layer_idx = layer_idx

self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)
self.next_layer_layernorm: RMSNorm = None

self.fusion_config = EagerFusionConfig()
14 changes: 8 additions & 6 deletions tensorrt_llm/_torch/modules/attention.py
@@ -126,7 +126,7 @@ def __init__(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])

@@ -140,7 +140,7 @@ def __init__(
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
)
allreduce_strategy=config.allreduce_strategy)

self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
@@ -481,7 +481,8 @@ def __init__(
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init)
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy)
else:
self.fused_a = Linear(
hidden_size,
@@ -501,7 +502,7 @@
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)
self.q_b_proj = self.q_proj

self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@@ -517,7 +518,7 @@
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init)
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy)
# This parameter will view into self.kv_b_proj.weight after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
# Used in forward_generation only
@@ -538,7 +540,7 @@
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)

def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe.py
@@ -356,7 +356,7 @@ def __init__(
self.parallel_size = self.mapping.tp_size

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)

self.intermediate_size_per_partition = intermediate_size // self.tp_size

@@ -935,7 +935,7 @@ def __init__(
self.parallel_size = self.mapping.tp_size

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)

self.intermediate_size_per_partition = intermediate_size // self.tp_size

@@ -70,7 +70,7 @@ def __init__(
self.parallel_size = self.mapping.tp_size

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)

self.intermediate_size_per_partition = intermediate_size // self.tp_size

2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -79,7 +79,7 @@ def __init__(
self.intermediate_size_per_partition = intermediate_size // self.tp_size

self.all_reduce = AllReduce(mapping=self.mapping,
strategy=model_config.allreduce_backend)
strategy=model_config.allreduce_strategy)

@abstractmethod
def create_weights(self):
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/gated_mlp.py
@@ -73,7 +73,7 @@ def __init__(self,
quant_config=config.get_quant_config(),
reduce_output=False,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -89,7 +89,7 @@ def __init__(self,
reduce_output=reduce_output,
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
)
allreduce_strategy=config.allreduce_strategy)

# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/modules/linear.py
@@ -13,7 +13,8 @@

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
from tensorrt_llm.mapping import Mapping

from ...models.modeling_utils import QuantConfig
@@ -658,6 +659,7 @@ def __init__(
skip_create_weights_in_init: bool = False,
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
):
from ..distributed import AllReduce

@@ -695,7 +697,8 @@
self.out_features = local_out_features

self.all_reduce = AllReduce(
mapping=self.mapping) if reduce_output else None
mapping=self.mapping,
strategy=allreduce_strategy) if reduce_output else None
self._weights_created = False
self.reduce_output = reduce_output
self.use_custom_cublas_mm = use_custom_cublas_mm
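Note: a hypothetical, trimmed-down illustration (not the real Linear class) of what this hunk wires up: the AllReduce is only constructed when the layer reduces its output, and it now receives the caller-provided strategy, which defaults to AUTO so call sites that do not pass the new keyword keep the previous behavior.

```python
from typing import Optional

from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping


class _RowParallelOutputSketch:
    """Simplified stand-in for Linear's reduce-output wiring."""

    def __init__(self,
                 mapping: Optional[Mapping] = None,
                 reduce_output: bool = True,
                 allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
        from tensorrt_llm._torch.distributed import AllReduce

        self.mapping = mapping or Mapping()
        # Mirrors the diff: no AllReduce (and hence no strategy) when the
        # output is not reduced.
        self.all_reduce = AllReduce(
            mapping=self.mapping,
            strategy=allreduce_strategy) if reduce_output else None
```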
36 changes: 17 additions & 19 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -88,15 +88,14 @@ def __init__(
self.is_paged_state = False

# in_proj
self.in_proj = Linear(
d_model,
d_in_proj,
bias=bias,
dtype=dtype,
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=config.get_quant_config(),
)
self.in_proj = Linear(d_model,
d_in_proj,
bias=bias,
dtype=dtype,
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=config.get_quant_config(),
allreduce_strategy=config.allreduce_strategy)

# conv1d, reuse Linear to store weights since it has support for TP > 1 already
self.conv1d = Linear(
@@ -108,7 +107,7 @@ def __init__(
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)

# A
self.A = nn.Parameter(
@@ -138,15 +137,14 @@ def __init__(
)

# out_proj
self.out_proj = Linear(
d_inner,
d_model,
bias=bias,
dtype=dtype,
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
)
self.out_proj = Linear(d_inner,
d_model,
bias=bias,
dtype=dtype,
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
allreduce_strategy=config.allreduce_strategy)

def forward(
self,
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/modules/mlp.py
@@ -43,7 +43,8 @@ def __init__(self,
weight_mode=WeightMode.VANILLA),
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.up_lora)
lora=self.up_lora,
allreduce_strategy=config.allreduce_strategy)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -56,7 +57,7 @@
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora)
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy)

def forward(
self,
@@ -88,7 +88,7 @@ def __init__(self,
)

self.allreduce = AllReduce(mapping=self.mapping,
allreduce_backend=self.strategy).cuda()
strategy=self.strategy).cuda()

self.input_tensors = []
for i in range(self.world_size):