
Commit 69185c5

yilin-void authored and timlee0212 committed
DeepEP LL support variable hidden size and tokens num (NVIDIA#6141)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
1 parent bb142c9 commit 69185c5

File tree

3 files changed: +11 -41 lines changed

cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT eb3f072664251c05074c3ecc3c3f5dad179c29a9)
+set(DEEP_EP_COMMIT 7b15af835942675df041eca2dcb9930b880287e1)
 set(NVSHMEM_URL_HASH
   SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

Lines changed: 4 additions & 7 deletions
@@ -100,7 +100,7 @@ class VariableLengthLowLatencyBuffer:
     def __init__(self, mapping: Mapping):
         self.comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank)
         self.buffer = None
-        self.num_max_dispatch_tokens_per_rank = None
+        self.num_experts = None
 
     def __del__(self):
         self.comm.Free()
@@ -120,6 +120,7 @@ def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
         allow_nvlink_for_low_latency_mode = (os.environ.get(
             "TRTLLM_DEEP_EP_DISABLE_P2P_FOR_LOW_LATENCY_MODE", "0") == "0")
 
+        assert self.num_experts is None or self.num_experts == num_experts
         # Allocate a buffer if not existed or not enough buffer size
         if self.buffer is None or self.buffer.num_rdma_bytes < num_rdma_bytes:
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
@@ -133,17 +134,13 @@ def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
                 allow_nvlink_for_low_latency_mode=
                 allow_nvlink_for_low_latency_mode,
                 comm=self.comm)
+            self.num_experts = num_experts
 
     def low_latency_dispatch(self, hidden_states: torch.Tensor,
                              topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int,
                              num_experts: int):
-        if self.num_max_dispatch_tokens_per_rank is None:
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
-        if num_max_dispatch_tokens_per_rank != self.num_max_dispatch_tokens_per_rank:
-            raise NotImplementedError(
-                "There are issues if `low_latency_dispatch` calls use different `num_max_dispatch_tokens_per_rank` values"
-            )
+        assert num_experts == self.num_experts
 
         # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
         recv_hidden_states, recv_expert_count, handle, event, hook = \
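
The net effect of this hunk is that VariableLengthLowLatencyBuffer no longer pins the first `num_max_dispatch_tokens_per_rank` it sees; it only requires the expert count to stay fixed once the buffer has been reserved, so later dispatch calls may use a different token count. Below is a minimal standalone sketch of that invariant, not the real DeepEP API: the class name `_BufferSketch`, the placeholder allocation, and the numeric sizes are illustrative assumptions.

class _BufferSketch:
    """Toy model of the new bookkeeping: cache num_experts, not the token count."""

    def __init__(self):
        self.buffer = None
        self.num_experts = None  # replaces the old cached num_max_dispatch_tokens_per_rank

    def reserve(self, num_max_dispatch_tokens_per_rank: int, hidden_size: int,
                num_experts: int):
        # The expert count must stay fixed for the lifetime of the buffer.
        assert self.num_experts is None or self.num_experts == num_experts
        if self.buffer is None:
            self.buffer = object()  # stand-in for the real RDMA buffer allocation
            self.num_experts = num_experts

    def low_latency_dispatch(self, num_max_dispatch_tokens_per_rank: int,
                             num_experts: int):
        # The old code raised NotImplementedError when the per-call token count
        # changed between calls; now only the expert count is checked.
        assert num_experts == self.num_experts


buf = _BufferSketch()
buf.reserve(num_max_dispatch_tokens_per_rank=256, hidden_size=7168, num_experts=256)
buf.low_latency_dispatch(num_max_dispatch_tokens_per_rank=256, num_experts=256)
buf.low_latency_dispatch(num_max_dispatch_tokens_per_rank=64, num_experts=256)  # now allowed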

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 6 additions & 33 deletions
@@ -463,15 +463,14 @@ def forward_chunk(
             if not use_postquant_alltoall:
                 deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
+                assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
                 x, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
-                # x shape: [#local experts, EP size * deep_ep_max_num_tokens, hidden_size]
+                    self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+                # x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size]
                 # recv_expert_count shape: [#local experts]
 
                 # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
                 # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
-                x = x[:, :self.mapping.moe_ep_size *
-                      all_rank_max_num_tokens]
                 mask = torch.arange(
                     x.shape[1], dtype=torch.int32, device=x.device).expand(
                         x.shape[0],
@@ -615,26 +614,14 @@ def forward_chunk(
 
                 deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
-                # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
-                # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
-                # Therefore, if the hidden size for the next LL dispatch/combine call is different from the current kernel call, manual cleaning is necessary.
-                if packed_hidden_size != hidden_size:
-                    self.deep_ep_buffer.clean_low_latency_buffer(
-                        self.deep_ep_max_num_tokens, packed_hidden_size,
-                        self.num_slots)
+
+                assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
                 fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                    self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
-                if packed_hidden_size != hidden_size:
-                    self.deep_ep_buffer.clean_low_latency_buffer(
-                        self.deep_ep_max_num_tokens, hidden_size,
-                        self.num_slots)
+                    self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
                 deep_ep_handle = list(deep_ep_handle)
                 deep_ep_handle[3] = hidden_size
                 deep_ep_handle = tuple(deep_ep_handle)
 
-                fp4_packed_tensor = fp4_packed_tensor[:, :self.mapping.
-                                                      moe_ep_size *
-                                                      all_rank_max_num_tokens]
                 assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
                     2] == packed_hidden_size
                 x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
@@ -707,23 +694,9 @@ def forward_chunk(
                     final_hidden_states, deep_ep_handle)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 num_tokens_per_expert_for_fused_moe = self.mapping.moe_ep_size * all_rank_max_num_tokens
-                num_tokens_per_expert_for_deep_ep = self.deep_ep_max_num_tokens * self.mapping.moe_ep_size
                 final_hidden_states = final_hidden_states.view(
                     self.expert_size_per_partition,
                     num_tokens_per_expert_for_fused_moe, self.hidden_size)
-                if num_tokens_per_expert_for_deep_ep != num_tokens_per_expert_for_fused_moe:
-                    # Adapter between fused_moe num_tokens and DeepEP num_tokens
-                    # This adapter can be removed if fused_moe accepts DeepEP num_tokens without overhead
-                    final_hidden_states_for_fused_moe = final_hidden_states
-                    final_hidden_states = torch.empty(
-                        self.expert_size_per_partition,
-                        self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
-                        self.hidden_size,
-                        dtype=final_hidden_states.dtype,
-                        device=final_hidden_states.device)
-                    final_hidden_states[:, :
-                                        num_tokens_per_expert_for_fused_moe] = final_hidden_states_for_fused_moe
-                    del final_hidden_states_for_fused_moe  # Release memory
                 final_hidden_states = self.deep_ep_buffer.low_latency_combine(
                     final_hidden_states, deep_ep_topk_idx, deep_ep_topk_weights,
                     deep_ep_handle)
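
Taken together, these hunks size the low-latency dispatch by the actual `all_rank_max_num_tokens` instead of the reserved `self.deep_ep_max_num_tokens`, so the dispatched tensor already has the shape that `torch.ops.trtllm.fused_moe` expects; the slice-down and pad-back adapters, as well as the hidden-size-dependent `clean_low_latency_buffer` calls, are therefore removed. A rough shape check follows, using made-up sizes for illustration only.

import torch

# Hypothetical sizes, for illustration only.
num_local_experts, ep_size = 8, 4
deep_ep_max_num_tokens = 256    # upper bound the buffer was reserved for
all_rank_max_num_tokens = 48    # actual max tokens per rank this iteration
hidden_size = 1024

# Before: dispatch sized by deep_ep_max_num_tokens, so the result had to be
# sliced down for fused_moe and padded back up before low_latency_combine.
x_old = torch.empty(num_local_experts, ep_size * deep_ep_max_num_tokens, hidden_size)
x_old = x_old[:, :ep_size * all_rank_max_num_tokens]  # the removed adapter

# After: dispatch sized by all_rank_max_num_tokens, so the output already
# matches what fused_moe consumes and feeds straight into the combine.
x_new = torch.empty(num_local_experts, ep_size * all_rank_max_num_tokens, hidden_size)
assert x_old.shape == x_new.shape == (num_local_experts,
                                      ep_size * all_rank_max_num_tokens,
                                      hidden_size)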
