23 changes: 16 additions & 7 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
@@ -2048,7 +2048,8 @@ Buffer::low_latency_dispatch_two_stage(
     int num_experts,
     bool use_fp8,
     bool async,
-    bool return_recv_hook) {
+    bool return_recv_hook,
+    int num_per_channel) {
   EP_HOST_ASSERT(low_latency_mode);

   // Tensor checks
@@ -2063,7 +2064,8 @@ Buffer::low_latency_dispatch_two_stage(

   auto num_tokens = static_cast<int>(x.size(0)),
        hidden = static_cast<int>(x.size(1));
-  auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
+  auto num_scales = num_per_channel == -1 ? 1 : hidden / 128,
Review comment (Contributor): If num_per_channel is being introduced, shouldn't this become hidden / num_per_channel?

Reply (Contributor Author): In that case, per-token quantization would need hidden_size passed in as num_per_channel, which makes the parameter a bit cumbersome.

+       num_topk = static_cast<int>(topk_idx.size(1));
   int num_local_experts = num_experts / num_ranks;

   // Buffer control
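The thread above concerns how many FP8 scales accompany each token. A minimal host-side sketch of the convention this hunk settles on (our reading: num_per_channel == -1 selects per-token quantization with a single scale, while any other value keeps the hard-coded 128-wide groups; the helper name is ours, not the PR's):

```cpp
#include <cassert>

// Hypothetical helper mirroring the expression in the hunk above:
// one scale per token when num_per_channel == -1 (per-token quantization),
// otherwise one scale per 128-element group along the hidden dimension.
// Note the group size stays 128 regardless of num_per_channel's value,
// which is exactly what the review comment questions.
inline int compute_num_scales(int hidden, int num_per_channel) {
  if (num_per_channel == -1) return 1;  // per-token: a single scale
  assert(hidden % 128 == 0);            // group size is still hard-coded
  return hidden / 128;                  // per-group (128-wide) scales
}
```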
@@ -2120,7 +2122,7 @@ Buffer::low_latency_dispatch_two_stage(
       (num_ranks / NUM_MAX_NVL_PEERS * (num_topk * 3 + 1) * sizeof(int) +
        sizeof(int4) - 1) /
           sizeof(int4) * sizeof(int4) +
-      (use_fp8 ? (hidden + num_scales * sizeof(float))
+      (use_fp8 ? (hidden + (num_scales + 3) / 4 * 4 * sizeof(float))
                : (hidden * sizeof(nv_bfloat16)));
   auto packed_rdma_recv_x = ConvertPaddleTensorToDetailTensor(
       paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS,
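The revised term rounds the scale count up to a multiple of four floats so the FP8 scale block stays int4-aligned (16 bytes), which now matters because num_scales can be 1. A standalone check of that arithmetic, assuming nothing beyond the expression in the hunk (the helper name is ours):

```cpp
#include <cstdio>

// Round the scale block up to a multiple of 4 floats (16 bytes), matching
// the (num_scales + 3) / 4 * 4 * sizeof(float) term in the hunk above.
constexpr int padded_scale_bytes(int num_scales) {
  return (num_scales + 3) / 4 * 4 * static_cast<int>(sizeof(float));
}

int main() {
  std::printf("%d\n", padded_scale_bytes(1));   // per-token: 16 bytes, not 4
  std::printf("%d\n", padded_scale_bytes(56));  // hidden 7168 / 128: 224, unchanged
  return 0;
}
```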
@@ -2181,7 +2183,8 @@ Buffer::low_latency_dispatch_two_stage(
         workspace,
         launch_stream,
         phases,
-        low_latency_buffer_idx);
+        low_latency_buffer_idx,
+        num_per_channel);
   };
   launcher(return_recv_hook
                ? LOW_LATENCY_SEND_PHASE
@@ -2222,6 +2225,7 @@ Buffer::low_latency_combine_two_stage(
     bool dispatch_use_fp8,
     bool async,
     bool return_recv_hook,
+    int num_per_channel,
    const std::optional<deep_ep::detail::Tensor>& out) {
   EP_HOST_ASSERT(low_latency_mode);
@@ -2308,7 +2312,8 @@ Buffer::low_latency_combine_two_stage(
         launch_stream,
         phases,
         dispatch_use_fp8,
-        low_latency_buffer_idx);
+        low_latency_buffer_idx,
+        num_per_channel);
   };
   launcher(return_recv_hook
                ? LOW_LATENCY_SEND_PHASE
@@ -3098,7 +3103,8 @@ Buffer::low_latency_dispatch_two_stage_api(const paddle::Tensor& x,
     int num_experts,
     bool use_fp8,
     bool async,
-    bool return_recv_hook) {
+    bool return_recv_hook,
+    int num_per_channel) {
 #ifdef PADDLE_WITH_NVSHMEM
   const auto& x_ = ConvertPaddleTensorToDetailTensor(x);
   const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx);
@@ -3111,7 +3117,8 @@ Buffer::low_latency_dispatch_two_stage_api(const paddle::Tensor& x,
       num_experts,
       use_fp8,
       async,
-      return_recv_hook);
+      return_recv_hook,
+      num_per_channel);

   auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));

@@ -3169,6 +3176,7 @@ Buffer::low_latency_combine_two_stage_api(
     bool dispatch_use_fp8,
     bool async,
     bool return_recv_hook,
+    int num_per_channel,
     const std::optional<paddle::Tensor>& out) {
 #ifdef PADDLE_WITH_NVSHMEM
   const auto& x_ = ConvertPaddleTensorToDetailTensor(x);
@@ -3200,6 +3208,7 @@ Buffer::low_latency_combine_two_stage_api(
       dispatch_use_fp8,
       async,
       return_recv_hook,
+      num_per_channel,
       out_);

   auto combined_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
8 changes: 6 additions & 2 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
@@ -315,7 +315,8 @@ struct Buffer {
       int num_experts,
       bool use_fp8,
       bool async,
-      bool return_recv_hook);
+      bool return_recv_hook,
+      int num_per_channel);

   std::tuple<deep_ep::detail::Tensor,
              std::optional<EventHandle>,
@@ -334,6 +335,7 @@ struct Buffer {
       bool dispatch_use_fp8,
       bool async,
       bool return_recv_hook,
+      int num_per_channel,
       const std::optional<deep_ep::detail::Tensor>& out);

   std::tuple<deep_ep::detail::Tensor,
@@ -488,7 +490,8 @@ struct Buffer {
       int num_experts,
       bool use_fp8,
       bool async,
-      bool return_recv_hook);
+      bool return_recv_hook,
+      int num_per_channel);

   std::tuple<paddle::Tensor,
              std::optional<EventHandle>,
Expand All @@ -507,6 +510,7 @@ struct Buffer {
bool dispatch_use_fp8,
bool async,
bool return_recv_hook,
int num_per_channel,
const std::optional<paddle::Tensor>& out);

std::tuple<paddle::Tensor,
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh
@@ -376,7 +376,8 @@ void dispatch(void* packed_recv_x,
              void* workspace,
              cudaStream_t stream,
              int phases,
-             int next_buffer_id);
+             int next_buffer_id,
+             int num_per_channel);

void combine(void* combined_x,
             void* rdma_recv_x,
@@ -404,7 +405,8 @@ void combine(void* combined_x,
             cudaStream_t stream,
             int phases,
             bool dispatch_use_fp8,
-            int next_buffer_id);
+            int next_buffer_id,
+            int num_per_channel);

void clean_low_latency_buffer_two_stage(void** buffer_ptrs_gpu,
                                        const size_t max_nvl_num_bytes,
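Putting the sizing pieces together, here is a hedged reconstruction of the per-token payload that the new parameter ultimately feeds into; the function name is ours, the expressions are copied from the deep_ep.cpp hunks above, and nv_bfloat16 comes from cuda_bf16.h:

```cpp
#include <cstddef>
#include <cuda_bf16.h>

// Per-token payload inside the RDMA receive buffer, as sized in
// low_latency_dispatch_two_stage: FP8 stores one byte per element plus a
// 16-byte-aligned block of float scales; BF16 stores two bytes per element
// and carries no scales.
size_t payload_bytes_per_token(int hidden, int num_per_channel, bool use_fp8) {
  if (!use_fp8) {
    return static_cast<size_t>(hidden) * sizeof(nv_bfloat16);
  }
  int num_scales = num_per_channel == -1 ? 1 : hidden / 128;
  size_t padded_scales = static_cast<size_t>(num_scales + 3) / 4 * 4;
  return static_cast<size_t>(hidden)         // 1 byte per FP8 value
         + padded_scales * sizeof(float);    // int4-aligned scale block
}
```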