Merged
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -40,6 +40,7 @@ add_library(
mlaPreprocessOp.cpp
allgatherOp.cpp
allreduceOp.cpp
alltoallOp.cpp
attentionOp.cpp
causalConv1dOp.cpp
convertSpecDecodingMaskToPackedMaskOp.cpp
126 changes: 126 additions & 0 deletions cpp/tensorrt_llm/thop/alltoallOp.cpp
@@ -0,0 +1,126 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <NvInferRuntime.h>
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <set>
#include <string>
#include <torch/extension.h>
#include <vector>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE

namespace torch_ext
{
#if ENABLE_MULTI_DEVICE

namespace
{

class AllToAllHelixOp
{
public:
AllToAllHelixOp(std::set<int> group)
: mGroup(std::move(group))
{
}

~AllToAllHelixOp() = default;

int initialize()
{
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
mNcclComm = getComm(mGroup);
TLLM_CHECK_WITH_INFO(mGroup.size() > 0, "group size should be greater than 0");
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
return 0;
}

std::vector<torch::Tensor> run(torch::TensorList input_list, torch::optional<int64_t> num_lists)
{
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
        auto num_lists_ = static_cast<int>(num_lists.value_or(1));
        auto num_ranks = static_cast<int>(mGroup.size());
        TLLM_CHECK_WITH_INFO(num_lists_ > 0, "num_lists should be greater than 0");
        // note: together with the checks above, this ensures that input_list is not empty
        TLLM_CHECK_WITH_INFO(static_cast<int>(input_list.size()) == num_ranks * num_lists_,
            "input_list size should be equal to group size * num_lists");
std::vector<torch::Tensor> output_list(static_cast<size_t>(num_lists_));
auto stream = at::cuda::getCurrentCUDAStream(input_list[0].get_device());
        NCCLCHECK_THROW(ncclGroupStart());
for (int il = 0; il < num_lists_; ++il)
{
auto off = il * num_ranks;
auto output_shape = input_list[off].sizes().vec();
output_shape.insert(output_shape.begin(), num_ranks);
auto output = torch::empty(output_shape, input_list[off].options());
output_list[il] = output;
auto type = tensorrt_llm::runtime::TorchUtils::dataType(input_list[off].scalar_type());
auto nccl_type = (*getDtypeMap())[type];
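            // Pair a send and a receive with every peer r: send the tensor destined
            // for rank r and receive rank r's contribution into row r of the stacked output.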
for (int r = 0; r < num_ranks; ++r)
{
auto const& input = input_list[off + r];
ncclSend(input.data_ptr(), input.numel(), nccl_type, r, *mNcclComm, stream);
ncclRecv(output[r].mutable_data_ptr(), output[r].numel(), nccl_type, r, *mNcclComm, stream);
}
}
NCCLCHECK_THROW(ncclGroupEnd());
return output_list;
}

private:
std::set<int> mGroup;
std::shared_ptr<ncclComm_t> mNcclComm;
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

std::vector<torch::Tensor> alltoall_helix(
torch::TensorList input_list, torch::List<int64_t> group_, torch::optional<int64_t> num_lists)
{
#if ENABLE_MULTI_DEVICE
std::set<int> group;
for (int64_t rank : group_)
{
group.insert(static_cast<int>(rank));
}
AllToAllHelixOp op(group);
op.initialize();
return op.run(input_list, num_lists);
#else
return {};
#endif // ENABLE_MULTI_DEVICE
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("alltoall_helix", &torch_ext::alltoall_helix);
}
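
The run() loop above expresses the all-to-all as one ncclSend/ncclRecv pair per peer inside a single NCCL group call: for list il, the tensor at slot il * num_ranks + r goes to rank r, and row r of the stacked output receives rank r's contribution. The following single-process sketch reproduces that data movement without NCCL; the helper name and the inputs_per_rank layout are illustrative only (not part of this change), and it mirrors the transpose relationship checked by the unit test below.

from typing import List

import torch


def alltoall_helix_reference(inputs_per_rank: List[List[torch.Tensor]],
                             num_lists: int) -> List[List[torch.Tensor]]:
    # inputs_per_rank[src] is rank src's flat input list of length
    # num_lists * num_ranks; slot il * num_ranks + dst is destined for rank dst.
    num_ranks = len(inputs_per_rank)
    outputs_per_rank = []
    for dst in range(num_ranks):
        outputs = []
        for il in range(num_lists):
            off = il * num_ranks
            # Row src of the output on rank dst is the tensor that rank src
            # holds at slot off + dst in its own input list.
            rows = [inputs_per_rank[src][off + dst] for src in range(num_ranks)]
            outputs.append(torch.stack(rows))
        outputs_per_rank.append(outputs)
    return outputs_per_rank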
9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -504,3 +504,12 @@ def _(router_logits, topk, output_dtype: torch.dtype = None):
return router_logits.new_empty(
sz, dtype=torch.int32), router_logits.new_empty(sz,
dtype=output_dtype)

@torch.library.register_fake("trtllm::alltoall_helix")
def _(input_list, group, num_lists):
    num_ranks = len(group)
    return [
        input_list[i].new_empty((num_ranks, ) + input_list[i].shape)
        for i in range(0, len(input_list), num_ranks)
    ]
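
The register_fake entry above supplies a meta implementation, so shape inference (for example under torch.compile tracing or FakeTensorMode) works without launching NCCL. A minimal sketch of what it reports, assuming tensorrt_llm has been imported so the trtllm::alltoall_helix schema and its fake are registered:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import tensorrt_llm  # assumption: importing this registers trtllm::alltoall_helix

with FakeTensorMode():
    group = [0, 1]
    inputs = [torch.empty(16, 128, dtype=torch.bfloat16) for _ in range(len(group))]
    outs = torch.ops.trtllm.alltoall_helix(inputs, group, 1)
    # One output per list of inputs, with a new leading dimension of size len(group).
    assert outs[0].shape == (len(group), 16, 128)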
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/distributed/__init__.py
@@ -2,11 +2,12 @@

from .communicator import Distributed, MPIDist, PPComm, TorchDist
from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce,
MoEAllReduceParams, allgather, reducescatter,
MoEAllReduceParams, allgather, alltoall_helix, reducescatter,
userbuffers_allreduce_finalize)

__all__ = [
"allgather",
"alltoall_helix",
"reducescatter",
"userbuffers_allreduce_finalize",
"AllReduce",
38 changes: 38 additions & 0 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -229,6 +229,44 @@ def convert_output(x, x_info):
return output


def alltoall_helix(
inputs: List[torch.Tensor],
group: List[int],
) -> List[torch.Tensor]:
'''
Add an operation that performs a collective all-to-all across a given group.
    The operation is implemented with a torch op that wraps an NCCL group call around a series of
    NCCL send/recv operations. See the following materials for details:
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all,
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html.
Args:
inputs (List[Tensor]): The input tensors.
Its length must be a multiple of the group size,
and all tensors in a group must have the same shape.
group (List[int]): The group of ranks to participate in the all-to-all.
Returns:
The output tensors.
For each group of input tensors (of size group size),
there is one output tensor with shape (group size, *input shape).
'''
n_ranks = len(group)
if n_ranks == 1:
return inputs

assert n_ranks > 0, "group must be non-empty"
assert n_ranks == len(set(group)), "group must be unique"

assert len(inputs) % n_ranks == 0,\
"inputs length must be a multiple of the group size"
num_lists = len(inputs) // n_ranks
for il in range(num_lists):
ref_input = inputs[il * n_ranks]
assert all([inputs[i].shape == ref_input.shape for i in range(il * n_ranks + 1, (il + 1) * n_ranks)]),\
"all input tensors in a group must have the same shape"

return torch.ops.trtllm.alltoall_helix(inputs, group, num_lists)


def reducescatter(
input: Union[torch.Tensor, List[torch.Tensor]],
mapping: Mapping,
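
A usage sketch for the new alltoall_helix wrapper added in ops.py above. It assumes one process per GPU (here ranks 0-3 of a 4-GPU MPI job, as in the tests below) and an initialized tensorrt_llm communication session so the NCCL communicator for the group can be created; the shapes are illustrative.

import torch

import tensorrt_llm
from tensorrt_llm._torch.distributed import alltoall_helix

rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(rank)
group = [0, 1, 2, 3]

# Two lists of per-destination tensors: inputs[il * len(group) + r] is sent to rank r.
inputs = [torch.randn(16, 128, device="cuda") for _ in range(len(group))] + \
         [torch.randn(32, 64, device="cuda") for _ in range(len(group))]

outputs = alltoall_helix(inputs, group)
# One output per list of inputs, stacked along a new leading rank dimension.
assert outputs[0].shape == (len(group), 16, 128)
assert outputs[1].shape == (len(group), 32, 64)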
150 changes: 150 additions & 0 deletions tests/unittest/_torch/multi_gpu/test_alltoall.py
@@ -0,0 +1,150 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import sys
import traceback

import cloudpickle
import pytest
import torch
from mpi4py import MPI

import tensorrt_llm
from tensorrt_llm._torch.distributed import alltoall_helix

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
cloudpickle.dumps,
cloudpickle.loads,
pickle.HIGHEST_PROTOCOL,
)

# needed since we reuse the MPI executor pool; the first test that runs will leak a thread
pytestmark = pytest.mark.threadleak(enabled=False)


def run_single_rank(single_rank_forward_func, *args, **kwargs):
rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(rank)
try:
single_rank_forward_func(*args, **kwargs)
except Exception:
traceback.print_exc()
raise
return True


@torch.inference_mode()
def run_alltoall_op(input_tensors, expected_recv_tensors, group):
"""Run alltoall_helix operation on a single rank."""
rank = tensorrt_llm.mpi_rank()
input_tensors = input_tensors[rank]
expected_recv_tensors = expected_recv_tensors[rank]
torch.cuda.set_device(rank)

# Move input tensors to GPU
input_tensors = [t.cuda() for t in input_tensors]

# Call alltoall_helix
output_tensors = alltoall_helix(input_tensors, group)

# Verify output
expected_recv_tensors = [t.cuda() for t in expected_recv_tensors]

assert len(output_tensors) * len(group) == len(input_tensors)
assert len(output_tensors) == len(expected_recv_tensors)

for i, output_tensor in enumerate(output_tensors):
assert output_tensor.dtype == expected_recv_tensors[i].dtype
assert output_tensor.device == expected_recv_tensors[i].device
assert output_tensor.shape == expected_recv_tensors[i].shape

assert torch.allclose(output_tensor, expected_recv_tensors[i])

return True


def run_alltoall_test(mpi_pool_executor, dtypes, shapes):
torch.manual_seed(0)
world_size = mpi_pool_executor.num_workers
num_lists = len(shapes)

# Create input tensors for each rank
send_tensors = []
for rank in range(world_size):
send_tensors.append([])
for list_idx in range(num_lists):
send_tensors[-1].append([])
for r in range(world_size):
# Each rank creates a tensor with unique data to send to rank `r`
tensor = torch.randn(*shapes[list_idx], dtype=dtypes[list_idx])
send_tensors[-1][-1].append(tensor)
expected_recv_tensors = []
    # Given the tensors that each rank sends to every destination rank,
    # determine the tensors each rank expects to receive
for rank in range(world_size):
expected_recv_tensors.append([])
        # For each list of sent tensors, determine the tensors received
for list_idx in range(num_lists):
# The received tensors are a transpose of the sent tensors
recv_tensors = [
send_tensors[r][list_idx][rank] for r in range(world_size)
]
expected_recv_tensors[-1].append(torch.stack(recv_tensors))

input_tensors = [[x for y in tensors for x in y]
for tensors in send_tensors]

# Create group list
group = list(range(world_size))

results = mpi_pool_executor.map(
run_single_rank,
*zip(*[(run_alltoall_op, input_tensors, expected_recv_tensors, group)] *
world_size),
)
for r in results:
assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("seq_len", [16, 256, 1024],
ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("hidden_size", [128, 2048, 7168],
ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_alltoall_2gpu(seq_len, hidden_size, mpi_pool_executor):
dtypes = [torch.bfloat16, torch.float]
shapes1 = [(seq_len, hidden_size)]
run_alltoall_test(mpi_pool_executor, dtypes, shapes1)
shapes2 = [(seq_len, hidden_size), (seq_len + 1, hidden_size + 1)]
run_alltoall_test(mpi_pool_executor, dtypes, shapes2)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Requires at least 4 GPUs for this test")
@pytest.mark.parametrize("seq_len", [28, 1004], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("hidden_size", [36, 6284], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("mpi_pool_executor", [4], indirect=True)
def test_alltoall_4gpu(seq_len, hidden_size, mpi_pool_executor):
dtypes = [torch.bfloat16, torch.float]
shapes1 = [(seq_len, hidden_size)]
run_alltoall_test(mpi_pool_executor, dtypes, shapes1)
shapes2 = [(seq_len, hidden_size), (seq_len + 1, hidden_size + 1)]
run_alltoall_test(mpi_pool_executor, dtypes, shapes2)