diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
index e0ab8c88423233..dd8e8d70e14aea 100644
--- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
+++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
@@ -2048,7 +2048,8 @@ Buffer::low_latency_dispatch_two_stage(
     int num_experts,
     bool use_fp8,
     bool async,
-    bool return_recv_hook) {
+    bool return_recv_hook,
+    int num_per_channel) {
   EP_HOST_ASSERT(low_latency_mode);
 
   // Tensor checks
@@ -2063,7 +2064,8 @@ Buffer::low_latency_dispatch_two_stage(
 
   auto num_tokens = static_cast<int>(x.size(0)),
        hidden = static_cast<int>(x.size(1));
-  auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
+  auto num_scales = num_per_channel == -1 ? 1 : hidden / 128,
+       num_topk = static_cast<int>(topk_idx.size(1));
   int num_local_experts = num_experts / num_ranks;
 
   // Buffer control
@@ -2120,7 +2122,7 @@ Buffer::low_latency_dispatch_two_stage(
       (num_ranks / NUM_MAX_NVL_PEERS * (num_topk * 3 + 1) * sizeof(int) +
        sizeof(int4) - 1) /
          sizeof(int4) * sizeof(int4) +
-      (use_fp8 ? (hidden + num_scales * sizeof(float))
+      (use_fp8 ? (hidden + (num_scales + 3) / 4 * 4 * sizeof(float))
               : (hidden * sizeof(nv_bfloat16)));
   auto packed_rdma_recv_x = ConvertPaddleTensorToDetailTensor(
       paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS,
@@ -2181,7 +2183,8 @@ Buffer::low_latency_dispatch_two_stage(
                        workspace,
                        launch_stream,
                        phases,
-                       low_latency_buffer_idx);
+                       low_latency_buffer_idx,
+                       num_per_channel);
   };
 
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE
@@ -2222,6 +2225,7 @@ Buffer::low_latency_combine_two_stage(
     bool dispatch_use_fp8,
     bool async,
     bool return_recv_hook,
+    int num_per_channel,
     const std::optional& out) {
   EP_HOST_ASSERT(low_latency_mode);
 
@@ -2308,7 +2312,8 @@ Buffer::low_latency_combine_two_stage(
                       launch_stream,
                       phases,
                       dispatch_use_fp8,
-                      low_latency_buffer_idx);
+                      low_latency_buffer_idx,
+                      num_per_channel);
   };
 
   launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE
@@ -3098,7 +3103,8 @@ Buffer::low_latency_dispatch_two_stage_api(const paddle::Tensor& x,
                                            int num_experts,
                                            bool use_fp8,
                                            bool async,
-                                           bool return_recv_hook) {
+                                           bool return_recv_hook,
+                                           int num_per_channel) {
 #ifdef PADDLE_WITH_NVSHMEM
   const auto& x_ = ConvertPaddleTensorToDetailTensor(x);
   const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx);
@@ -3111,7 +3117,8 @@ Buffer::low_latency_dispatch_two_stage_api(const paddle::Tensor& x,
                                       num_experts,
                                       use_fp8,
                                       async,
-                                      return_recv_hook);
+                                      return_recv_hook,
+                                      num_per_channel);
 
   auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
@@ -3169,6 +3176,7 @@ Buffer::low_latency_combine_two_stage_api(
     bool dispatch_use_fp8,
     bool async,
     bool return_recv_hook,
+    int num_per_channel,
     const std::optional& out) {
 #ifdef PADDLE_WITH_NVSHMEM
   const auto& x_ = ConvertPaddleTensorToDetailTensor(x);
@@ -3200,6 +3208,7 @@ Buffer::low_latency_combine_two_stage_api(
                                      dispatch_use_fp8,
                                      async,
                                      return_recv_hook,
+                                     num_per_channel,
                                      out_);
 
   auto combined_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
index 94eef41ea0854f..f951bd56549d3c 100644
--- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
+++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
@@ -315,7 +315,8 @@ struct Buffer {
       int num_experts,
       bool use_fp8,
       bool async,
-      bool return_recv_hook);
+      bool return_recv_hook,
+      int num_per_channel);
 
   std::tuple,
@@ -334,6 +335,7 @@ struct Buffer {
       bool dispatch_use_fp8,
       bool async,
       bool return_recv_hook,
+      int num_per_channel,
      const std::optional& out);
 
   std::tuple,
@@ -507,6 +510,7 @@ struct Buffer {
      bool dispatch_use_fp8,
      bool async,
      bool return_recv_hook,
+     int num_per_channel,
      const std::optional& out);
 
   std::tuple
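Note on the size arithmetic above and in the kernel changes that follow: the host side now reserves (num_scales + 3) / 4 * 4 floats for the per-token scale slots, while the kernels use AlignUpElems(kNumScales, kAlignElems) with kAlignElems = sizeof(int4) / sizeof(float). Both expressions round the number of float scale slots up to a whole number of int4 (16-byte) slots, so the per-message layout stays int4-aligned even when num_per_channel == -1 leaves only a single scale per token. A minimal Python sketch of that rounding (align_up_elems is a hypothetical stand-in for the kernel-side AlignUpElems helper, which is not shown in this patch):

# Sketch only (not part of the patch): the host-side expression
# (num_scales + 3) // 4 * 4 and the kernel-side AlignUpElems(kNumScales, kAlignElems)
# round the scale slot count up to an int4 (16-byte) boundary.
def align_up_elems(n: int, align: int) -> int:
    """Hypothetical Python stand-in for the kernel-side AlignUpElems helper."""
    return (n + align - 1) // align * align

ALIGN_ELEMS = 16 // 4  # sizeof(int4) / sizeof(float)

for num_scales in (1, 7168 // 128):  # per-token (-1) case and an example per-128-channel case
    host_side = (num_scales + 3) // 4 * 4
    assert host_side == align_up_elems(num_scales, ALIGN_ELEMS)
    print(num_scales, "->", host_side, "float scale slots")

With num_per_channel == -1 there is one scale per token, so the reserved region grows from 1 to 4 floats; with the default 128-channel grouping, hidden / 128 is already a multiple of 4 whenever hidden is a multiple of 512, and the padding is a no-op.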
 __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer_two_stage(void** buffer_ptrs_gpu,
@@ -99,7 +103,8 @@
 template
-          int kNumQPs>
+          int kNumQPs,
+          int kNumPerChannels = 128>
 __global__ __launch_bounds__(
     kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void dispatch_kernel(void* packed_recv_x,
@@ -157,10 +162,11 @@ __global__ __launch_bounds__(
   }
 
   // FP8 staffs
-  constexpr int kNumPerChannels = 128;
   constexpr float kFP8Margin = 1e-4, kFP8Amax = 448,
                   kFP8AmaxInv = 1.0f / 448.0f;
-  constexpr int kNumScales = kHidden / kNumPerChannels;
+  constexpr int kNumScales =
+      kNumPerChannels == -1 ? 1 : kHidden / kNumPerChannels;
+  constexpr int kAlignElems = sizeof(int4) / sizeof(float);
   const size_t hidden_bytes =
       kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
   const size_t hidden_int4 = hidden_bytes / sizeof(int4);
@@ -172,12 +178,15 @@ __global__ __launch_bounds__(
       sizeof(int4) +
       (kNumRdmaRanks * (kTopk * 3 + 1) * sizeof(int) + sizeof(int4) - 1) /
           sizeof(int4) * sizeof(int4) +
-      (kUseFP8 ? (kHidden + kNumScales * sizeof(float))
-               : (kHidden * sizeof(nv_bfloat16)));
+      (kUseFP8
+           ? (kHidden + AlignUpElems(kNumScales, kAlignElems) * sizeof(float))
+           : (kHidden * sizeof(nv_bfloat16)));
   // rdma_index_source, hidden, (scale)
   const size_t num_bytes_per_msg_rdma_revecier_and_nvl_sender =
-      sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float))
-                              : (kHidden * sizeof(nv_bfloat16)));
+      sizeof(int4) +
+      (kUseFP8
+           ? (kHidden + AlignUpElems(kNumScales, kAlignElems) * sizeof(float))
+           : (kHidden * sizeof(nv_bfloat16)));
   constexpr size_t combine_num_bytes_per_msg = kHidden * sizeof(nv_bfloat16);
   const size_t DISPATCH_NVL_BUFFER_X_BYTES =
       kNumLocalExperts * kNumRanks * num_max_dispatch_tokens_per_rank *
@@ -200,8 +209,9 @@ __global__ __launch_bounds__(
   const size_t NVL_BUFFER_OFFSET =
       nvl_buffer_id * NVL_BUFFER_X_BYTES_PER_BUFFER;
   const size_t num_bytes_per_msg_rdma_to_nvl =
-      kUseFP8 ? (kHidden + kNumScales * sizeof(float))
-              : (kHidden * sizeof(nv_bfloat16));
+      kUseFP8
+          ? (kHidden + AlignUpElems(kNumScales, kAlignElems) * sizeof(float))
+          : (kHidden * sizeof(nv_bfloat16));
   const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
   const size_t num_int4_per_msg_rdma_revecier_and_nvl_sender =
       num_bytes_per_msg_rdma_revecier_and_nvl_sender / sizeof(int4);
@@ -240,35 +250,63 @@ __global__ __launch_bounds__(
     const auto rdma_x_scales = reinterpret_cast<float*>(
        reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
     const auto index_source = rdma_x_src_idx;
-    const auto nvl_rank_meta =
-        reinterpret_cast<int*>(rdma_x_scales + (kUseFP8 ? kNumScales : 0));
+    const auto nvl_rank_meta = reinterpret_cast<int*>(
+        rdma_x_scales +
+        (kUseFP8 ? AlignUpElems(kNumScales, kAlignElems) : 0));
     thread_id == 0 ? (*index_source = token_idx) : 0;
 
+    if constexpr (kUseFP8 &&
+                  kNumPerChannels == -1) {  // fp8 per-token dynamic quant
+      const auto warp_nums = kNumWarpGroups * kNumWarpsPerGroup;
+      __shared__ float amax_cache[warp_nums];
+      for (int i = thread_id; i < warp_nums; i += num_threads) {
+        amax_cache[i] = 0.0f;
+      }
+      __syncthreads();
+      float amax = kFP8Margin, scale, scale_inv;
 #pragma unroll
-    for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
-      // Read
-      auto int4_value = __ldg(x_int4 + i);
-
-      if (kUseFP8) {
-        // Calculate local amax
+      for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
+        auto int4_value = __ldg(x_int4 + i);
         auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
         float fp32_values[kNumElemsPerRead];
-        float amax = kFP8Margin, scale, scale_inv;
 #pragma unroll
         for (int j = 0; j < kNumElemsPerRead; ++j) {
           fp32_values[j] = static_cast<float>(bf16_values[j]);
           amax = fmaxf(amax, fabsf(fp32_values[j]));
         }
-        // Reduce amax and scale
-        EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2,
+        EP_STATIC_ASSERT((kNumPerChannels == -1) ||
+                             (kNumElemsPerRead * 32 / kNumPerChannels == 2),
                          "Invalid vectorization");
-        amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax,
-        scale_inv = amax * kFP8AmaxInv;
-        if (lane_id == 0 || lane_id == 16)
-          rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+        amax = warp_reduce_max(amax);
+        if (lane_id == 0) {
+          amax_cache[warp_id] = amax;
+        }
+      }
+      __syncthreads();
+      if (warp_id == 0) {
+        float thread_amax = lane_id < warp_nums ? amax_cache[lane_id] : 0.0f;
+        thread_amax = warp_reduce_max(thread_amax);
+        if (lane_id == 0) {
+          amax_cache[0] = thread_amax;
+        }
+      }
+      __syncthreads();
+      amax = amax_cache[0];
+      scale = 440.f / amax;
+      // scale_inv = amax * kFP8AmaxInv;
+      if (threadIdx.x == 0) {
+        rdma_x_scales[0] = amax;
+      }
+      for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
+        auto int4_value = __ldg(x_int4 + i);
+        auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
+        float fp32_values[kNumElemsPerRead];
+        for (int j = 0; j < kNumElemsPerRead; ++j) {
+          fp32_values[j] = static_cast<float>(bf16_values[j]);
+        }
         // Cast into send buffer
         vec_t int2_value;
         auto fp8x2_values =
@@ -281,9 +319,48 @@ __global__ __launch_bounds__(
               __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
         }
         rdma_x_vec[i] = int2_value;
-      } else {
-        // Reinterpret-cast is for C++14 compatibility
-        rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
+      }
+    } else {
+#pragma unroll
+      for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
+        // Read
+        auto int4_value = __ldg(x_int4 + i);
+
+        if constexpr (kUseFP8) {
+          // Calculate local amax
+          auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
+          float fp32_values[kNumElemsPerRead];
+          float amax = kFP8Margin, scale, scale_inv;
+#pragma unroll
+          for (int j = 0; j < kNumElemsPerRead; ++j) {
+            fp32_values[j] = static_cast<float>(bf16_values[j]);
+            amax = fmaxf(amax, fabsf(fp32_values[j]));
+          }
+
+          // Reduce amax and scale
+          EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2,
+                           "Invalid vectorization");
+          amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax,
+          scale_inv = amax * kFP8AmaxInv;
+          if (lane_id == 0 || lane_id == 16)
+            rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+
+          // Cast into send buffer
+          vec_t int2_value;
+          auto fp8x2_values =
+              reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
+#pragma unroll
+          for (int j = 0; j < kNumElemsPerRead; j += 2) {
+            float2 fp32x2 = {fp32_values[j] * scale,
+                             fp32_values[j + 1] * scale};
+            fp8x2_values[j / 2] =
+                __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
+          }
+          rdma_x_vec[i] = int2_value;
+        } else {
+          // Reinterpret-cast is for C++14 compatibility
+          rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
+        }
       }
     }
     __syncthreads();
@@ -477,7 +554,8 @@ LOW_LATENCY_DISPATCH_RECV:
     const auto rdma_recv_x_scales = reinterpret_cast<float*>(
        reinterpret_cast<uint8_t*>(src_data) + sizeof(int4) + hidden_bytes);
     const auto rdma_recv_nvl_rank_meta = reinterpret_cast<int*>(
-        rdma_recv_x_scales + (kUseFP8 ? kNumScales : 0));
+        rdma_recv_x_scales +
+        (kUseFP8 ? AlignUpElems(kNumScales, kAlignElems) : 0));
     const int dst_nvl_experts =
         *(rdma_recv_nvl_rank_meta + rdma_rank * (kTopk * 3 + 1));
     const auto rdma_recv_nvl_rank_meta_now =
@@ -651,16 +729,23 @@ LOW_LATENCY_DISPATCH_RECV:
         const auto dst_scales =
             reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
         const auto scale_stride = kNumRanks * num_max_dispatch_tokens_per_rank;
-        auto scale_0 =
-            lane_id < kNumScales ? ld_nc_global(src_scales + lane_id) : 0;
-        auto scale_1 = (lane_id + 32) < kNumScales
-                           ? ld_nc_global(src_scales + lane_id + 32)
-                           : 0;
-        lane_id < kNumScales ? dst_scales[lane_id * scale_stride] = scale_0
-                             : 0.0f;
-        (lane_id + 32) < kNumScales
-            ? dst_scales[(lane_id + 32) * scale_stride] = scale_1
-            : 0.0f;
+        if constexpr (kNumPerChannels == -1) {
+          if (lane_id == 0) {
+            auto scale = ld_nc_global(src_scales);
+            dst_scales[0] = scale;
+          }
+        } else {
+          auto scale_0 =
+              lane_id < kNumScales ? ld_nc_global(src_scales + lane_id) : 0;
+          auto scale_1 = (lane_id + 32) < kNumScales
+                             ? ld_nc_global(src_scales + lane_id + 32)
+                             : 0;
+          lane_id < kNumScales ? dst_scales[lane_id * scale_stride] = scale_0
+                               : 0.0f;
+          (lane_id + 32) < kNumScales
+              ? dst_scales[(lane_id + 32) * scale_stride] = scale_1
+              : 0.0f;
+        }
       }
     }
   }
@@ -694,7 +779,8 @@ void dispatch(void* packed_recv_x,
               void* workspace,
               cudaStream_t stream,
               int phases,
-              int next_buffer_id) {
+              int next_buffer_id,
+              int num_per_channel) {
   constexpr int kNumMaxTopK = 8;
   constexpr int kNumQPs = 32;
   constexpr int NUM_WARPS = 32;
@@ -736,65 +822,73 @@ void dispatch(void* packed_recv_x,
           {DISPATCH_NUM_EXPERTS(
               num_experts,
               kNumExperts,
-              {DISPATCH_NUM_WARP_GROUPS(num_warp_groups, kNumWarpGroups, {
-                constexpr int kNumWarpsPerGroup =
-                    NUM_WARPS / kNumWarpGroups;
-                assert(num_rdma_ranks <=
-                       kNumWarpGroups * kNumWarpsPerGroup);
-                EP_STATIC_ASSERT(
-                    kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup,
-                    "Too many top-k selections");
-                auto dispatch_func =
-                    use_fp8 ? dispatch_kernel
-                            : dispatch_kernel;
-                SETUP_LAUNCH_CONFIG(num_sms,
-                                    kNumWarpGroups * kNumWarpsPerGroup * 32,
-                                    stream);
-                LAUNCH_KERNEL(&cfg,
-                              dispatch_func,
-                              packed_recv_x,
-                              packed_recv_x_scales,
-                              packed_rdma_recv_x,
-                              packed_recv_src_info,
-                              packed_recv_layout_range,
-                              packed_recv_count,
-                              packed_rdma_recv_count,
-                              rdma_send_flags,
-                              rdma_recv_x,
-                              rdma_recv_count,
-                              rdma_x,
-                              nvl_recv_x,
-                              x,
-                              topk_idx,
-                              topk_weights,
-                              atomic_counter_per_expert,
-                              atomic_counter_per_rdma,
-                              atomic_finished_counter_per_rdma,
-                              atomic_recv_tokens_per_rdma_expert,
-                              atomic_nvl_sender_multi_sms,
-                              atomic_counter_per_qp,
-                              next_clean,
-                              num_next_clean_int,
-                              num_tokens,
-                              num_max_dispatch_tokens_per_rank,
-                              rank,
-                              phases,
-                              next_buffer_id);
-              })})})})});
+              {DISPATCH_NUM_WARP_GROUPS(
+                  num_warp_groups,
+                  kNumWarpGroups,
+                  {DISPATCH_NUM_PER_CHANNEL(
+                      num_per_channel, kNumPerChannels, {
+                        constexpr int kNumWarpsPerGroup =
+                            NUM_WARPS / kNumWarpGroups;
+                        assert(num_rdma_ranks <=
+                               kNumWarpGroups * kNumWarpsPerGroup);
+                        EP_STATIC_ASSERT(
+                            kNumMaxTopK + 1 <=
+                                kNumWarpGroups * kNumWarpsPerGroup,
+                            "Too many top-k selections");
+                        auto dispatch_func =
+                            use_fp8 ? dispatch_kernel
+                                    : dispatch_kernel;
+                        SETUP_LAUNCH_CONFIG(
+                            num_sms,
+                            kNumWarpGroups * kNumWarpsPerGroup * 32,
+                            stream);
+                        LAUNCH_KERNEL(&cfg,
+                                      dispatch_func,
+                                      packed_recv_x,
+                                      packed_recv_x_scales,
+                                      packed_rdma_recv_x,
+                                      packed_recv_src_info,
+                                      packed_recv_layout_range,
+                                      packed_recv_count,
+                                      packed_rdma_recv_count,
+                                      rdma_send_flags,
+                                      rdma_recv_x,
+                                      rdma_recv_count,
+                                      rdma_x,
+                                      nvl_recv_x,
+                                      x,
+                                      topk_idx,
+                                      topk_weights,
+                                      atomic_counter_per_expert,
+                                      atomic_counter_per_rdma,
+                                      atomic_finished_counter_per_rdma,
+                                      atomic_recv_tokens_per_rdma_expert,
+                                      atomic_nvl_sender_multi_sms,
+                                      atomic_counter_per_qp,
+                                      next_clean,
+                                      num_next_clean_int,
+                                      num_tokens,
+                                      num_max_dispatch_tokens_per_rank,
+                                      rank,
+                                      phases,
+                                      next_buffer_id);
+                      })})})})})});
 }
 
 template
-          int kNumQPs>
+          int kNumQPs,
+          int kNumPerChannels = 128>
 __global__ __launch_bounds__(
     kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void combine_kernel(void* combined_x,
@@ -837,19 +932,23 @@ __global__ __launch_bounds__(
   constexpr int kNumRanks = kNumRdmaRanks * NUM_MAX_NVL_PEERS;
   constexpr int kNumLocalExperts = kNumExperts / kNumRanks;
   constexpr int kNumRdmaExperts = kNumLocalExperts * NUM_MAX_NVL_PEERS;
-  constexpr int kNumPerChannels = 128;
-  constexpr int kNumScales = kHidden / kNumPerChannels;
+  constexpr int kAlignElems = sizeof(int4) / sizeof(float);
+  constexpr int kNumScales =
+      kNumPerChannels == -1 ? 1 : kHidden / kNumPerChannels;
   const int nvl_buffer_id = next_buffer_id ^ 1;
   const size_t num_bytes_per_msg_dispatch =
       sizeof(int4) +
       (kNumRdmaRanks * (kTopk * 3 + 1) * sizeof(int) + sizeof(int4) - 1) /
          sizeof(int4) * sizeof(int4) +
-      (kDispatchUseFP8 ? (kHidden + kNumScales * sizeof(float))
-                       : (kHidden * sizeof(nv_bfloat16)));
+      (kDispatchUseFP8
+           ? (kHidden + AlignUpElems(kNumScales, kAlignElems) * sizeof(float))
+           : (kHidden * sizeof(nv_bfloat16)));
   const size_t num_bytes_per_msg_rdma_revecier_and_nvl_sender_dispatch =
-      sizeof(int4) + (kDispatchUseFP8 ? (kHidden + kNumScales * sizeof(float))
-                                      : (kHidden * sizeof(nv_bfloat16)));
+      sizeof(int4) +
+      (kDispatchUseFP8
+           ? (kHidden + AlignUpElems(kNumScales, kAlignElems) * sizeof(float))
+           : (kHidden * sizeof(nv_bfloat16)));
   const size_t dispatch_hidden_bytes =
       kHidden *
@@ -1032,7 +1131,9 @@ __global__ __launch_bounds__(
           reinterpret_cast<const int*>(dispatch_rdma_recv_x_now)[0];
       const int* nvl_rank_meta = reinterpret_cast<const int*>(
           dispatch_rdma_recv_x_now + sizeof(int4) + dispatch_hidden_bytes +
-          (kDispatchUseFP8 ? kNumScales * sizeof(float) : 0));
+          (kDispatchUseFP8
+               ? AlignUpElems(kNumScales, kAlignElems) * sizeof(float)
+               : 0));
       const int nvl_rank_nums =
           *(nvl_rank_meta + rdma_rank * (kTopk * 3 + 1));
       const int* nvl_rank_meta_now =
@@ -1215,7 +1316,8 @@ void combine(void* combined_x,
              cudaStream_t stream,
              int phases,
              bool dispatch_use_fp8,
-             int next_buffer_id) {
+             int next_buffer_id,
+             int num_per_channel) {
   constexpr int kNumMaxTopk = 8;
   constexpr int kNumQPs = 4;
   constexpr int NUM_WARPS = 32;
@@ -1245,58 +1347,66 @@ void combine(void* combined_x,
           {DISPATCH_NUM_EXPERTS(
               num_experts,
               kNumExperts,
-              {DISPATCH_NUM_WARP_GROUPS(num_warp_groups, kNumWarpGroups, {
-                constexpr int kNumWarpsPerGroup =
-                    NUM_WARPS / kNumWarpGroups;
-                auto combine_func = dispatch_use_fp8
-                                        ? combine_kernel
-                                        : combine_kernel;
-                SETUP_LAUNCH_CONFIG(num_sms,
-                                    kNumWarpGroups * kNumWarpsPerGroup * 32,
-                                    stream);
-                LAUNCH_KERNEL(&cfg,
-                              combine_func,
-                              combined_x,
-                              rdma_recv_x,
-                              rdma_recv_flag,
-                              rdma_send_x,
-                              dispatch_rdma_recv_x,
-                              dispatch_rdma_recv_count,
-                              nvl_buffer,
-                              x,
-                              topk_idx,
-                              topk_weights,
-                              src_info,
-                              layout_range,
-                              rdma_send_flags,
-                              next_clean,
-                              num_next_clean_int,
-                              atomic_clean_flag,
-                              atomic_nvl_sender_multi_sms,
-                              num_combined_tokens,
-                              hidden,
-                              num_topk,
-                              num_max_dispatch_tokens_per_rank,
-                              num_experts,
-                              rank,
-                              num_ranks,
-                              phases,
-                              next_buffer_id);
-              })})})})})
+              {DISPATCH_NUM_WARP_GROUPS(
+                  num_warp_groups,
+                  kNumWarpGroups,
+                  {DISPATCH_NUM_PER_CHANNEL(
+                      num_per_channel, kNumPerChannels, {
+                        constexpr int kNumWarpsPerGroup =
+                            NUM_WARPS / kNumWarpGroups;
+                        auto combine_func =
+                            dispatch_use_fp8
+                                ? combine_kernel
+                                : combine_kernel;
+                        SETUP_LAUNCH_CONFIG(
+                            num_sms,
+                            kNumWarpGroups * kNumWarpsPerGroup * 32,
+                            stream);
+                        LAUNCH_KERNEL(&cfg,
+                                      combine_func,
+                                      combined_x,
+                                      rdma_recv_x,
+                                      rdma_recv_flag,
+                                      rdma_send_x,
+                                      dispatch_rdma_recv_x,
+                                      dispatch_rdma_recv_count,
+                                      nvl_buffer,
+                                      x,
+                                      topk_idx,
+                                      topk_weights,
+                                      src_info,
+                                      layout_range,
+                                      rdma_send_flags,
+                                      next_clean,
+                                      num_next_clean_int,
+                                      atomic_clean_flag,
+                                      atomic_nvl_sender_multi_sms,
+                                      num_combined_tokens,
+                                      hidden,
+                                      num_topk,
+                                      num_max_dispatch_tokens_per_rank,
+                                      num_experts,
+                                      rank,
+                                      num_ranks,
+                                      phases,
+                                      next_buffer_id);
+                      })})})})})})
 }
 
 }  // namespace internode_ll_two_stage
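The kNumPerChannels == -1 branch above replaces the per-128-channel FP8 scaling with per-token dynamic quantization: each warp reduces its local amax into the shared amax_cache, warp 0 then reduces across warps, the token's amax is written into the single scale slot (rdma_x_scales[0]), and every element is multiplied by 440.f / amax before the saturating E4M3 conversion; the receive side correspondingly copies one scale per token instead of kNumScales entries. A rough CPU-side Python sketch of the two scaling modes (illustrative only — the clamp stands in for the saturating __nv_cvt_float2_to_fp8x2 conversion, which additionally rounds values onto the FP8 grid):

FP8_MARGIN = 1e-4   # kFP8Margin in the kernel
FP8_MAX = 448.0     # E4M3 finite max (kFP8Amax)

def per_token_quant(hidden_vec):
    """Illustrative per-token dynamic quantization: one scale for the whole token."""
    amax = max(FP8_MARGIN, max(abs(v) for v in hidden_vec))
    scale = 440.0 / amax                     # the kernel uses 440.f / amax
    quantized = [max(-FP8_MAX, min(FP8_MAX, v * scale)) for v in hidden_vec]
    return quantized, amax                   # the kernel stores amax in rdma_x_scales[0]

def per_channel_quant(hidden_vec, group=128):
    """Illustrative per-128-channel quantization (the original kNumPerChannels == 128 path)."""
    out, scales = [], []
    for i in range(0, len(hidden_vec), group):
        chunk = hidden_vec[i:i + group]
        amax = max(FP8_MARGIN, max(abs(v) for v in chunk))
        scales.append(amax / FP8_MAX)        # scale_inv = amax * kFP8AmaxInv
        out.extend(max(-FP8_MAX, min(FP8_MAX, v * (FP8_MAX / amax))) for v in chunk)
    return out, scales

Note that the per-token path stores the raw amax rather than a reciprocal scale (the scale_inv assignment is commented out in the kernel), so whatever consumes packed_recv_x_scales has to follow that convention when dequantizing.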
diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py
index 7b6a8cb5ba148e..5778881b3b26eb 100644
--- a/python/paddle/distributed/communication/deep_ep/buffer.py
+++ b/python/paddle/distributed/communication/deep_ep/buffer.py
@@ -1055,6 +1055,7 @@ def low_latency_dispatch_two_stage(
         use_fp8: bool = True,
         async_finish: bool = False,
         return_recv_hook: bool = False,
+        num_per_channel: int = 128,
     ) -> tuple[
         tuple[paddle.Tensor, paddle.Tensor],
         paddle.Tensor,
@@ -1121,6 +1122,7 @@ def low_latency_dispatch_two_stage(
             use_fp8,
             async_finish,
             return_recv_hook,
+            num_per_channel,
         )
         handle = (
             packed_recv_rdma_x,
@@ -1163,6 +1165,7 @@ def low_latency_combine_two_stage(
         dispatch_use_fp8: bool = False,
         async_finish: bool = False,
         return_recv_hook: bool = False,
+        num_per_channel: int = 128,
         out: paddle.Tensor | None = None,
     ) -> tuple[paddle.Tensor, EventOverlap, Callable]:
         """
@@ -1218,6 +1221,7 @@ def low_latency_combine_two_stage(
             dispatch_use_fp8,
             async_finish,
             return_recv_hook,
+            num_per_channel,
             out,
         )
        tensors_to_record = (
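On the Python side the new argument defaults to num_per_channel=128, so existing callers keep the per-128-channel FP8 scales; passing -1 selects the per-token path, and the same value should be passed to the combine call, since the combine kernel uses it to compute the layout of the dispatch RDMA buffer it reads back. A hedged usage sketch (buffer and the pre-existing arguments of both methods are assumed to be set up exactly as before this patch; only num_per_channel is new here):

# Sketch only: `buffer` is an initialized deep_ep Buffer, and existing_dispatch_args /
# existing_combine_args stand for the arguments these methods already accepted
# before this patch (x, topk_idx, handle, ...).
packed_recv, *rest = buffer.low_latency_dispatch_two_stage(
    *existing_dispatch_args,
    use_fp8=True,
    num_per_channel=-1,  # -1: one dynamic FP8 scale per token; 128 (default): per-128-channel scales
)

combined_x, event, hook = buffer.low_latency_combine_two_stage(
    *existing_combine_args,
    dispatch_use_fp8=True,
    num_per_channel=-1,  # keep consistent with the value used at dispatch time
)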