diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
index 934679a944c..6758558e277 100644
--- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
+++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
@@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
     ONESHOT = 4,
     TWOSHOT = 5,
     LOWPRECISION = 6,
+    MNNVL = 7,
+    NCCL_SYMMETRIC = 8,
 };
 
 enum class AllReduceStrategyConfig : int8_t
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp
index 81222885241..e0f2d5cce2e 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp
+++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp
@@ -14,47 +14,58 @@
  * limitations under the License.
  */
 #include "ub_allocator.h"
+#include "tensorrt_llm/common/opUtils.h"
+#include <memory>
+#include <set>
 
 namespace tensorrt_llm::runtime::ub
 {
 UserBufferAllocator& UserBufferAllocator::Instance()
 {
-    static UserBufferAllocator _;
-    return _;
+    if (use_nccl_symmetric)
+    {
+        static NCCLUserBufferAllocator _;
+        return _;
+    }
+    else
+    {
+        static UserBufferAllocator _;
+        return _;
+    }
 }
 
-void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& world_config)
+void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
 {
-    if (!is_initialized())
+    if (!isInitialized())
     {
-        ub_comm_ = nullptr;
-        world_config_ = world_config;
-        create_communicator_grouped2(&ub_comm_, world_config_);
-        TLLM_CHECK(ub_comm_ != nullptr);
-        is_initialized_ = true;
+        mUbComm = nullptr;
+        mWorldConfig = worldConfig;
+        create_communicator_grouped2(&mUbComm, worldConfig);
+        TLLM_CHECK(mUbComm != nullptr);
+        mIsInitialized = true;
     }
 }
 
-bool UserBufferAllocator::is_initialized()
+bool UserBufferAllocator::isInitialized()
 {
-    return is_initialized_;
+    return mIsInitialized;
 }
 
-UBBuffer UserBufferAllocator::register_ub_buffer(size_t bytes)
+UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes)
 {
-    TLLM_CHECK(is_initialized());
+    TLLM_CHECK(isInitialized());
     void* addr = nullptr;
     int handle = -1;
-    handle = register_user_buffer_collective((void**) &addr, bytes, ub_comm_);
+    handle = register_user_buffer_collective((void**) &addr, bytes, mUbComm);
     return {addr, handle, bytes};
 }
 
 UBBuffer UserBufferAllocator::allocate(size_t bytes)
 {
-    TLLM_CHECK(is_initialized());
-    auto ub_buffer = register_ub_buffer(bytes);
+    TLLM_CHECK(isInitialized());
+    auto ub_buffer = registerUBBuffer(bytes);
     TLLM_CHECK(!ub_buffer.invalid());
-    buffers_.push_back(ub_buffer);
+    mBuffers.push_back(ub_buffer);
     return ub_buffer;
 }
 
@@ -62,13 +73,177 @@
 void UserBufferAllocator::deallocate(void* addr) {}
 
 UBBuffer UserBufferAllocator::get(int idx)
 {
-    TLLM_CHECK(is_initialized() && idx < buffers_.size() && !buffers_[idx].invalid());
-    return buffers_[idx];
+    TLLM_CHECK(isInitialized() && idx < mBuffers.size() && !mBuffers[idx].invalid());
+    return mBuffers[idx];
 }
 
 communicator* UserBufferAllocator::comm()
 {
-    TLLM_CHECK(is_initialized());
-    return ub_comm_;
+    TLLM_CHECK(isInitialized());
+    return mUbComm;
+}
+
+void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
+{
+    if (!isInitialized())
+    {
+        TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator");
+        std::set<int> group;
+        for (int i = 0; i < worldConfig.getSize(); i++)
+        {
+            group.insert(i);
+        }
+        mComm = getComm(group);
+        mIsInitialized = true;
+    }
 }
+
+UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
+{
+    TLLM_CHECK(isInitialized());
+    UBBuffer ub_buffer;
+
+    auto& ncclHelper = getNCCLHelper();
+    if (!ncclHelper.isLoaded())
+    {
+        TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
+    }
+
+    auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc();
+    auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister();
+
+    NCCLCHECK(ncclMemAllocFunc(&ub_buffer.addr, bytes));
+    NCCLCHECK(ncclCommWindowRegisterFunc((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC));
+    ub_buffer.handle = 5;
+    ub_buffer.size = bytes;
+    return ub_buffer;
+}
+
+// Static member definitions
+std::unique_ptr<NCCLHelper> NCCLUserBufferAllocator::mNCCLHelper = nullptr;
+
+NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper()
+{
+    if (!mNCCLHelper)
+    {
+        mNCCLHelper = std::make_unique<NCCLHelper>();
+    }
+    return *mNCCLHelper;
+}
+
+// NCCLHelper implementation
+NCCLHelper::NCCLHelper()
+    : mLibraryHandle(nullptr)
+    , mNCCLCommWindowRegister(nullptr)
+    , mNCCLMemAlloc(nullptr)
+    , mIsLoaded(false)
+{
+    loadNCCLLibrary();
+}
+
+NCCLHelper::~NCCLHelper()
+{
+    if (mLibraryHandle)
+    {
+#ifdef _WIN32
+        FreeLibrary(mLibraryHandle);
+#else
+        dlclose(mLibraryHandle);
+#endif
+        mLibraryHandle = nullptr;
+    }
+}
+
+void NCCLHelper::loadNCCLLibrary()
+{
+    try
+    {
+#ifdef _WIN32
+        char const* libraryNames[] = {"nccl.dll", nullptr};
+#else
+        char const* libraryNames[] = {"libnccl.so", nullptr};
+#endif
+
+        for (int i = 0; libraryNames[i] != nullptr; ++i)
+        {
+            mLibraryHandle = loadLibraryHandle(libraryNames[i]);
+            if (mLibraryHandle)
+            {
+                TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]);
+                break;
+            }
+        }
+
+        if (!mLibraryHandle)
+        {
+            TLLM_LOG_WARNING("Failed to load NCCL library");
+            return;
+        }
+
+        // Load the required symbols
+        mNCCLCommWindowRegister
+            = reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));
+
+        mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));
+
+        if (mNCCLCommWindowRegister == nullptr)
+        {
+            TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
+        }
+
+        if (mNCCLMemAlloc)
+        {
+            mIsLoaded = true;
+        }
+        else
+        {
+            TLLM_LOG_WARNING("Failed to load required NCCL symbols");
+        }
+    }
+    catch (std::exception const& e)
+    {
+        TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
+    }
+}
+
+void* NCCLHelper::loadLibraryHandle(char const* libName)
+{
+#ifdef _WIN32
+    return LoadLibraryA(libName);
+#else
+    return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL);
+#endif
+}
+
+void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
+{
+    if (!handle)
+    {
+        return nullptr;
+    }
+
+#ifdef _WIN32
+    return GetProcAddress(static_cast<HMODULE>(handle), symbolName);
+#else
+    return dlsym(handle, symbolName);
+#endif
+}
+
+NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
+{
+    return mNCCLCommWindowRegister;
+}
+
+NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
+{
+    return mNCCLMemAlloc;
+}
+
+bool NCCLHelper::isLoaded() const
+{
+    return mIsLoaded;
+}
+
+bool UserBufferAllocator::use_nccl_symmetric = false;
+
 }; // namespace tensorrt_llm::runtime::ub
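The NCCLHelper above binds ncclMemAlloc and ncclCommWindowRegister at runtime through dlopen/dlsym (LoadLibraryA/GetProcAddress on Windows), so the build still loads against NCCL versions that predate the symmetric-window API; only ncclMemAlloc is required for isLoaded(), and a missing ncclCommWindowRegister just logs a warning. A minimal Python ctypes sketch of the same probing idea — the library name, and treating both symbols as required, are illustrative assumptions rather than part of this patch:

```python
import ctypes


def probe_nccl_symmetric_support(lib_name: str = "libnccl.so") -> bool:
    """Illustrative only: mirrors NCCLHelper's dlopen/dlsym probing."""
    try:
        lib = ctypes.CDLL(lib_name, mode=ctypes.RTLD_GLOBAL)
    except OSError:
        return False  # NCCL shared library not found at all
    # Attribute access on a CDLL performs the dlsym lookup;
    # a missing symbol raises AttributeError, so hasattr() reports availability.
    has_mem_alloc = hasattr(lib, "ncclMemAlloc")
    has_window_register = hasattr(lib, "ncclCommWindowRegister")
    return has_mem_alloc and has_window_register
```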
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h
index 9e5c2ee4cb4..37a48e50352 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h
+++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h
@@ -14,9 +14,16 @@
  * limitations under the License.
  */
 #pragma once
+#include "nccl.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
+#include <memory>
 #if ENABLE_MULTI_DEVICE
 #include "userbuffers.h"
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
 #endif
 
 namespace tensorrt_llm::runtime::ub
 {
@@ -28,11 +35,13 @@ struct UBBuffer
     void* addr;
     int handle;
     size_t size;
+    ncclWindow_t window;
 
-    UBBuffer(void* a = nullptr, int h = -1, size_t s = 0)
+    UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
         : addr(a)
         , handle(h)
         , size(s)
+        , window(w)
     {
     }
 
@@ -49,21 +58,74 @@ class UserBufferAllocator
 
     UserBufferAllocator() = default;
 
-    void initialize(tensorrt_llm::runtime::WorldConfig const& world_config);
-    bool is_initialized();
+    virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig);
+    bool isInitialized();
     UBBuffer allocate(size_t bytes);
     void deallocate(void* addr);
     UBBuffer get(int idx);
     communicator* comm();
+    virtual UBBuffer registerUBBuffer(size_t bytes);
+
+    static bool use_nccl_symmetric;
 
 private:
-    UBBuffer register_ub_buffer(size_t bytes);
+    communicator* mUbComm;
 
-    communicator* ub_comm_;
-    std::vector<UBBuffer> buffers_;
-    bool is_initialized_;
-    tensorrt_llm::runtime::WorldConfig world_config_;
+protected:
+    std::vector<UBBuffer> mBuffers;
+    bool mIsInitialized;
+    tensorrt_llm::runtime::WorldConfig mWorldConfig;
 };
+
+class NCCLHelper
+{
+public:
+    NCCLHelper();
+    ~NCCLHelper();
+
+    // Dynamic loading function type definition
+    using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
+    using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);
+
+    // Get function pointer for ncclCommWindowRegister
+    ncclCommWindowRegisterFunc getNCCLCommWindowRegister();
+
+    // Get function pointer for ncclMemAlloc
+    ncclMemAllocFunc getNCCLMemAlloc();
+
+    // Check if NCCL library is successfully loaded
+    bool isLoaded() const;
+
+private:
+    void loadNCCLLibrary();
+    void* loadLibraryHandle(char const* libName);
+    void* getSymbolAddress(void* handle, char const* symbolName);
+
+#ifdef _WIN32
+    HMODULE mLibraryHandle;
+#else
+    void* mLibraryHandle;
+#endif
+
+    ncclCommWindowRegisterFunc mNCCLCommWindowRegister;
+    ncclMemAllocFunc mNCCLMemAlloc;
+    bool mIsLoaded;
+};
+
+class NCCLUserBufferAllocator : public UserBufferAllocator
+{
+public:
+    void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override;
+    UBBuffer registerUBBuffer(size_t bytes) override;
+
+    // Get shared NCCLHelper instance
+    static NCCLHelper& getNCCLHelper();
+
+private:
+    std::shared_ptr<ncclComm_t> mComm;
+    static std::unique_ptr<NCCLHelper> mNCCLHelper;
+};
+
 #else
 using communicator = void;
 #endif
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
index d7a3e69981c..6d5f62b2604 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
+++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
@@ -36,7 +36,7 @@ void ub_initialize(int tp_size)
 
 bool ub_is_initialized()
 {
-    return UserBufferAllocator::Instance().is_initialized();
+    return UserBufferAllocator::Instance().isInitialized();
 }
 
 UBBuffer ub_allocate(size_t bytes)
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp
index c636eec3d97..a1fcd3c01fb 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp
+++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp
@@ -29,11 +29,14 @@ UserBuffersManager& UserBuffersManager::get_instance()
     return allocator;
 }
 
-void UserBuffersManager::initialize(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
+void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
 {
     std::lock_guard<std::mutex> lock(mutex_);
     tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node);
+#if ENABLE_MULTI_DEVICE
+    UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric;
+#endif
     tensorrt_llm::runtime::ub::ub_initialize(world_config);
     TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized());
     buffer_size_ = buffer_size;
@@ -95,10 +98,11 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm()
 {
     return tensorrt_llm::runtime::ub::ub_comm();
 }
 
-void initialize_userbuffers_manager(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
+void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
 {
-    UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size);
+    UserBuffersManager::get_instance().initialize(
+        tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric);
 }
 
 } // namespace tensorrt_llm::runtime::ub
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h
index 7ec39db602c..1b34f8e8a17 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h
+++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h
@@ -46,8 +46,9 @@ class UserBuffersManager
     //! @param gpus_per_node The number of GPUs per node.
     //! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this
     //! size.
-    void initialize(
-        int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
+    //! @param use_nccl_symmetric Whether to use NCCL symmetric communication.
+    void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node,
+        int64_t buffer_size, bool use_nccl_symmetric);
 
     //! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB
     //! buffer or create a new one if no available buffer is found.
@@ -75,7 +76,7 @@ class UserBuffersManager
     int64_t buffer_size_;
 };
 
-void initialize_userbuffers_manager(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
+void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric);
 
 } // namespace tensorrt_llm::runtime::ub
diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
index 432f7e6b136..c2bcb28e6a4 100644
--- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -456,7 +456,8 @@ void initBindings(pybind11::module_& m)
         .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
         .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
         .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
-        .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT);
+        .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
+        .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);
 
     // Initialize MoeLoadBalancer bindings
     initMoeBindings(m);
diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index b38aea3ecd1..6cdc1f92822 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -163,9 +163,9 @@ class AllreduceOp
     {
         size_t size = input.numel();
         size_t seq_len = input.size(0);
+        size_t bytes_per_element = input.element_size();
+        TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);
-
-        // If strategy is set to UB, UB must be used as UB impl output is special and cannot be used
-        // by others.
         AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
 
         // Log runtime strategy
@@ -177,6 +177,8 @@
         {
         case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
         case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
+        case AllReduceStrategyType::NCCL_SYMMETRIC:
+            return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
         case AllReduceStrategyType::MIN_LATENCY:
         case AllReduceStrategyType::ONESHOT:
         case AllReduceStrategyType::TWOSHOT:
@@ -303,6 +305,39 @@ class AllreduceOp
         return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
     }
 
+    std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
+        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
+        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
+    {
+
+        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+        int size = input.numel();
+        auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
+        auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
+        if (ub_buffer0.invalid())
+        {
+            auto [symmetric_input, symmetric_ub_buffer0]
+                = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
+            cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
+                cudaMemcpyDeviceToDevice, stream);
+            ub_buffer0 = symmetric_ub_buffer0;
+        }
+
+        TLLM_CHECK(!ub_buffer0.invalid());
+        auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
+
+        NCCLCHECK(ncclAllReduce(
+            ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
+
+        if (mOp == AllReduceFusionOp::NONE)
+        {
+            return {norm_out};
+        }
+
+        // Treat any other patterns as fallback cases.
+        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out);
+    }
+
     std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
@@ -633,6 +668,10 @@ class AllreduceOp
         {
             runtime_strategy = AllReduceStrategyType::NCCL;
         }
+        else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
+        {
+            runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
+        }
         else
         {
             // This is for DEBUG and BENCHMARK purpose. It will override the strategy if AUTO is set.
@@ -658,6 +697,11 @@ AllreduceOp
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
             break;
         }
+        case AllReduceStrategyType::NCCL_SYMMETRIC:
+        {
+            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
+            break;
+        }
         case AllReduceStrategyType::MIN_LATENCY:
         {
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
             break;
         }
@@ -673,7 +717,7 @@ AllreduceOp
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
             break;
         }
-        default: break;
+        default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
         }
     }
diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index ba713a7d566..5d44ecb263e 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 import threading
@@ -14,6 +15,7 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 
 _thread_local = threading.local()
 
+logger = logging.getLogger(__name__)
 
 
 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 232d2ccecd6..15cd4d618f2 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -126,7 +126,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
         "ONESHOT": AllReduceStrategy.ONESHOT,
         "TWOSHOT": AllReduceStrategy.TWOSHOT,
         "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
-        "MNNVL": AllReduceStrategy.MNNVL
+        "MNNVL": AllReduceStrategy.MNNVL,
+        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
     }
     key = strategy.upper()
     return maps[key] if key in maps else AllReduceStrategy.AUTO
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 2d00cee05f0..189b42885b5 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -322,10 +322,14 @@ def __init__(
                               and not self.enable_attention_dp)
 
         try:
+            use_ub_for_nccl = (
+                pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+                and self._init_userbuffers(self.model.config.hidden_size))
             if pytorch_backend_config.torch_compile_enabled:
                 set_torch_compiling(True)
-                use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers(
-                    self.model.config.hidden_size)
+                use_ub = not use_ub_for_nccl and (
+                    pytorch_backend_config.torch_compile_enable_userbuffers
+                    and self._init_userbuffers(self.model.config.hidden_size))
                 self._torch_compile_backend = Backend(
                     pytorch_backend_config.torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
@@ -2232,12 +2236,12 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
        if not ub.ub_supported():
            return False
-        ub.initialize_userbuffers_manager(self.mapping.tp_size,
-                                          self.mapping.pp_size,
-                                          self.mapping.cp_size,
-                                          self.mapping.rank,
-                                          self.mapping.gpus_per_node,
-                                          hidden_size * self.max_num_tokens * 2)
+        use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+        ub.initialize_userbuffers_manager(
+            self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
+            self.mapping.rank, self.mapping.gpus_per_node,
+            hidden_size * self.max_num_tokens * 2, use_nccl_symmetric)
+
         return True
 
     def load_weights_from_target_model(self,
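On the Python side, the only visible change is the extra use_nccl_symmetric flag on the userbuffers binding, which _init_userbuffers above derives from allreduce_strategy == "NCCL_SYMMETRIC". A minimal standalone sketch of that call — the single-node, TP-only world and the 2-bytes-per-element sizing are illustrative assumptions mirroring the code above:

```python
import torch
import tensorrt_llm.bindings.internal.userbuffers as ub


def init_ub_for_nccl_symmetric(tp_size: int, rank: int,
                               hidden_size: int, max_num_tokens: int) -> None:
    # Mirrors _init_userbuffers: one buffer sized as hidden_size * max_num_tokens
    # with 2-byte elements (illustrative; the real sizing comes from the model config).
    buffer_size = hidden_size * max_num_tokens * 2
    ub.initialize_userbuffers_manager(
        tp_size,                    # tensor-parallel size
        1,                          # pipeline-parallel size
        1,                          # context-parallel size
        rank,                       # rank of this process
        torch.cuda.device_count(),  # GPUs per node
        buffer_size,
        True)                       # use_nccl_symmetric -> NCCLUserBufferAllocator path
```

The same call appears in the updated unit tests further down, with False for the default (non-symmetric) path.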
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 06880bc4304..59c42d32ab4 100755
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -3882,6 +3882,7 @@ class AllReduceStrategy(IntEnum):
     TWOSHOT = 5
     LOWPRECISION = 6
     MNNVL = 7
+    NCCL_SYMMETRIC = 8
 
 
 class AllReduceFusionOp(IntEnum):
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index b7d46ed6fa2..f1a6a593242 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2090,14 +2090,12 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )
 
-    allreduce_strategy: Optional[
-        Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
-                'LOWPRECISION', 'MNNVL']] = Field(
-            default='AUTO',
-            description="Allreduce strategy to use.",
-            status="beta",
-        )
-
+    allreduce_strategy: Optional[Literal[
+        'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
+        'LOWPRECISION', 'MNNVL',
+        'NCCL_SYMMETRIC']] = Field(default='AUTO',
+                                   description="Allreduce strategy to use.",
+                                   status="beta")
     checkpoint_loader: Optional[object] = Field(
         default=None,
         description="The checkpoint loader to use for this LLM instance.",
diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
index e3d00f4683c..78b6cc61ea3 100644
--- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
+++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
@@ -21,9 +21,9 @@
 import torch
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
-from utils.util import skip_pre_blackwell
 
 import tensorrt_llm
+import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams)
 from tensorrt_llm.functional import AllReduceStrategy
@@ -55,6 +55,7 @@ def run_single_rank(
     dtype,
     fused_add_norm,
     reference_output_list,
+    strategy,
 ):
     rank = tensorrt_llm.mpi_rank()
     torch.cuda.set_device(rank)
@@ -70,6 +71,7 @@ def run_single_rank(
             rank,
             fused_add_norm,
             reference_output_list,
+            strategy,
         )
     except Exception:
         traceback.print_exc()
@@ -89,6 +91,7 @@ def row_linear_residual_norm_fusion_forward(
     tensor_parallel_rank: int,
     fusion: bool,
     reference_output_list: list[tuple[torch.Tensor, ...]],
+    strategy: AllReduceStrategy,
 ):
 
     # Move all tensors to GPU
@@ -100,6 +103,12 @@
         for ref_output in reference_output_list
     ]
 
+    if strategy == AllReduceStrategy.NCCL_SYMMETRIC:
+        ub.initialize_userbuffers_manager(
+            tensor_parallel_size, 1, 1, tensor_parallel_rank,
+            torch.cuda.device_count(),
+            x_list[0].nelement() * x_list[0].element_size(), True)
+
     MPI.COMM_WORLD.barrier()
 
     # Create a single AllReduce instance to be reused for all sequence lengths
@@ -109,7 +118,7 @@
             tp_size=tensor_parallel_size,
             rank=tensor_parallel_rank,
         ),
-        strategy=AllReduceStrategy.MNNVL,
+        strategy=strategy,
         dtype=dtype,
    )
 
@@ -152,7 +161,6 @@ def func(input, residual, norm_weight, eps, enable_fusion):
     )
 
 
-@skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize(
@@ -171,12 +179,16 @@
 @pytest.mark.parametrize("dtype",
                          [torch.float16, torch.bfloat16, torch.float32],
                          ids=lambda x: f"dtype:{torch.finfo(x).dtype}")
+@pytest.mark.parametrize(
+    "strategy", [AllReduceStrategy.MNNVL, AllReduceStrategy.NCCL_SYMMETRIC],
+    ids=lambda x: f"strategy:{x}")
 @pytest.mark.parametrize(
     "fusion",
     [True, False],
     ids=["fusion", "no_fusion"],
 )
-def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion):
+def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, strategy,
+                                         fusion):
     torch.manual_seed(42)
 
     tensor_parallel_size = 2
@@ -222,6 +234,7 @@ def test_row_linear_residual_norm_fusion(seq_len, hidden_size, dtype, fusion):
                     dtype,
                     fusion,
                     reference_output_list,
+                    strategy,
                 ) for i in range(tensor_parallel_size)
             ]),
     )
diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
index 601f5acfbc2..2248485fbf1 100644
--- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py
+++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
@@ -35,7 +35,8 @@ def init_userbuffers_allocator(tp_size, rank, max_ub_size):
     ub.initialize_userbuffers_manager(tp_size, 1, 1, rank,
-                                      torch.cuda.device_count(), max_ub_size)
+                                      torch.cuda.device_count(), max_ub_size,
+                                      False)
 
 
 def create_userbuffers_tensor(shape, dtype):
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 984c8953ecd..7844904d152 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -144,7 +144,7 @@ methods:
         default: False
         status: prototype
       allreduce_strategy:
-        annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL']]
+        annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC']]
         default: AUTO
         status: beta
       decoding_config:
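End to end, selecting the new path is a one-word change for users. A hedged sketch of both entry points — the checkpoint path, parallel sizes, and dtype are placeholders, and the LLM constructor is assumed to forward allreduce_strategy exactly as the TorchLlmArgs field above declares it:

```python
import torch

from tensorrt_llm import LLM
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm._torch.distributed import AllReduce
from tensorrt_llm.mapping import Mapping

# LLM-level: the new literal flows through TorchLlmArgs.allreduce_strategy.
llm = LLM(model="/path/to/model",  # placeholder checkpoint path
          tensor_parallel_size=2,
          allreduce_strategy="NCCL_SYMMETRIC")

# Op-level, as exercised in test_mnnvl_allreduce.py: pass the enum to the AllReduce module.
allreduce = AllReduce(mapping=Mapping(world_size=2, tp_size=2, rank=0),
                      strategy=AllReduceStrategy.NCCL_SYMMETRIC,
                      dtype=torch.bfloat16)
```

At the op level, NCCL_SYMMETRIC additionally expects the userbuffers manager to have been initialized with use_nccl_symmetric=True, as in the sketch after the model_engine.py hunk and in test_mnnvl_allreduce.py.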