Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
ONESHOT = 4,
TWOSHOT = 5,
LOWPRECISION = 6,
MNNVL = 7,
NCCL_SYMMETRIC = 8,
};

enum class AllReduceStrategyConfig : int8_t
Expand Down
217 changes: 196 additions & 21 deletions cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,61 +14,236 @@
* limitations under the License.
*/
#include "ub_allocator.h"
#include "tensorrt_llm/common/opUtils.h"
#include <set>
#include <stdexcept>

namespace tensorrt_llm::runtime::ub
{
// Returns the process-wide allocator singleton.
//
// The static flag use_nccl_symmetric is read on every call and selects which
// function-local static is returned: the NCCL-backed allocator when the flag
// is true, the plain user-buffer allocator otherwise. The two branches own
// separate static objects, so flipping the flag between calls yields a
// different instance.
UserBufferAllocator& UserBufferAllocator::Instance()
{
    if (use_nccl_symmetric)
    {
        static NCCLUserBufferAllocator _;
        return _;
    }
    else
    {
        static UserBufferAllocator _;
        return _;
    }
}

// Creates the user-buffer communicator on first use; later calls do nothing.
//
// worldConfig is copied into mWorldConfig and passed to
// create_communicator_grouped2. The allocator is only marked initialized after
// the communicator pointer has been checked to be non-null.
void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
    if (!isInitialized())
    {
        mUbComm = nullptr;
        mWorldConfig = worldConfig;
        create_communicator_grouped2(&mUbComm, worldConfig);
        TLLM_CHECK(mUbComm != nullptr);
        mIsInitialized = true;
    }
}

bool UserBufferAllocator::is_initialized()
bool UserBufferAllocator::isInitialized()
{
return is_initialized_;
return mIsInitialized;
}

// Registers a buffer of `bytes` bytes through register_user_buffer_collective
// and returns the resulting address, handle and size as a UBBuffer.
// Requires the allocator to be initialized.
UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes)
{
    TLLM_CHECK(isInitialized());
    void* addr = nullptr;
    // &addr already has type void**, so no cast is needed.
    int const handle = register_user_buffer_collective(&addr, bytes, mUbComm);
    return {addr, handle, bytes};
}

// Registers a new buffer of `bytes` bytes via the (virtual) registerUBBuffer,
// verifies that it is valid, records it in mBuffers and returns a copy of it.
// Requires the allocator to be initialized.
UBBuffer UserBufferAllocator::allocate(size_t bytes)
{
    TLLM_CHECK(isInitialized());
    auto ub_buffer = registerUBBuffer(bytes);
    TLLM_CHECK(!ub_buffer.invalid());
    mBuffers.push_back(ub_buffer);
    return ub_buffer;
}

// Currently a no-op: the body is empty, so `addr` is ignored and nothing
// recorded by allocate() is released here.
void UserBufferAllocator::deallocate(void* addr) {}

// Returns a copy of the idx-th buffer recorded by allocate().
// The check fails if the allocator is not initialized, if idx is outside
// [0, mBuffers.size()), or if the stored buffer is invalid.
UBBuffer UserBufferAllocator::get(int idx)
{
    // Make the negative-index case explicit instead of relying on the implicit
    // int -> size_t conversion in the comparison.
    TLLM_CHECK(isInitialized() && idx >= 0 && static_cast<size_t>(idx) < mBuffers.size()
        && !mBuffers[idx].invalid());
    return mBuffers[idx];
}

// Returns the communicator created by initialize().
// Requires the allocator to be initialized.
communicator* UserBufferAllocator::comm()
{
    TLLM_CHECK(isInitialized());
    return mUbComm;
}

// Obtains the NCCL communicator on first use; repeated calls return immediately.
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
    if (isInitialized())
    {
        return;
    }

    TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator");

    // The group handed to getComm holds every integer in [0, worldConfig.getSize()).
    std::set<int> members;
    for (int member = 0; member < worldConfig.getSize(); ++member)
    {
        members.insert(member);
    }

    mComm = getComm(members);
    mIsInitialized = true;
}

// Allocates `bytes` bytes with the dynamically resolved ncclMemAlloc and
// registers the memory as a symmetric window (NCCL_WIN_COLL_SYMMETRIC) on
// mComm via the dynamically resolved ncclCommWindowRegister.
//
// Returns a UBBuffer carrying the address, window, size and a fixed handle.
// Throws (via TLLM_THROW / NCCLCHECK) when the NCCL library or one of the two
// symbols is unavailable, or when an NCCL call reports an error.
UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
{
    TLLM_CHECK(isInitialized());

    auto& ncclHelper = getNCCLHelper();
    if (!ncclHelper.isLoaded())
    {
        TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
    }

    auto const memAlloc = ncclHelper.getNCCLMemAlloc();
    auto const windowRegister = ncclHelper.getNCCLCommWindowRegister();
    // isLoaded() does not guarantee that ncclCommWindowRegister was resolved,
    // so check both pointers before calling through them.
    if (memAlloc == nullptr || windowRegister == nullptr)
    {
        TLLM_THROW("Required NCCL symbols (ncclMemAlloc, ncclCommWindowRegister) are not available");
    }

    UBBuffer ub_buffer;
    NCCLCHECK(memAlloc(&ub_buffer.addr, bytes));
    NCCLCHECK(windowRegister((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC));
    // NOTE(review): 5 is a hard-coded handle value; its meaning is not visible
    // in this file - confirm with the consumers of UBBuffer::handle.
    ub_buffer.handle = 5;
    ub_buffer.size = bytes;
    return ub_buffer;
}

// Storage for the class-wide NCCLHelper; starts out empty.
std::unique_ptr<NCCLHelper> NCCLUserBufferAllocator::mNCCLHelper{};

// Returns the class-wide NCCLHelper, constructing it on the first call.
NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper()
{
    if (mNCCLHelper == nullptr)
    {
        mNCCLHelper = std::make_unique<NCCLHelper>();
    }
    return *mNCCLHelper;
}

// NCCLHelper implementation
NCCLHelper::NCCLHelper()
: mLibraryHandle(nullptr)
, mNCCLCommWindowRegister(nullptr)
, mNCCLMemAlloc(nullptr)
, mIsLoaded(false)
{
loadNCCLLibrary();
}

// Releases the library handle, if one was obtained.
NCCLHelper::~NCCLHelper()
{
    if (!mLibraryHandle)
    {
        return;
    }

#ifdef _WIN32
    FreeLibrary(mLibraryHandle);
#else
    dlclose(mLibraryHandle);
#endif
    mLibraryHandle = nullptr;
}

// Tries to open the NCCL shared library and resolve ncclCommWindowRegister
// and ncclMemAlloc from it.
//
// On return:
//  - mLibraryHandle is non-null iff one of the candidate names could be opened;
//  - mIsLoaded is true iff ncclMemAlloc was resolved (a missing
//    ncclCommWindowRegister is only reported with a warning).
// Failures are logged as warnings; std::exception is caught and logged rather
// than propagated.
void NCCLHelper::loadNCCLLibrary()
{
    try
    {
#ifdef _WIN32
        char const* const libraryNames[] = {"nccl.dll"};
#else
        char const* const libraryNames[] = {"libnccl.so"};
#endif

        // Iterate over the array itself: it has no nullptr sentinel, so a
        // "until nullptr" loop would read past the end when loading fails.
        for (char const* libraryName : libraryNames)
        {
            // The cast is a no-op where the handle is void* and performs the
            // required void* -> HMODULE conversion on Windows.
            mLibraryHandle = static_cast<decltype(mLibraryHandle)>(loadLibraryHandle(libraryName));
            if (mLibraryHandle)
            {
                TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryName);
                break;
            }
        }

        if (!mLibraryHandle)
        {
            TLLM_LOG_WARNING("Failed to load NCCL library");
            return;
        }

        // Load the required symbols
        mNCCLCommWindowRegister
            = reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));

        mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));

        if (mNCCLCommWindowRegister == nullptr)
        {
            TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
        }

        if (mNCCLMemAlloc)
        {
            mIsLoaded = true;
        }
        else
        {
            TLLM_LOG_WARNING("Failed to load required NCCL symbols");
        }
    }
    catch (std::exception const& e)
    {
        TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
    }
}

// Opens the shared library called `libName` with the platform loader and
// returns whatever handle the loader reports (null on failure).
void* NCCLHelper::loadLibraryHandle(char const* libName)
{
#ifdef _WIN32
    return LoadLibraryA(libName);
#else
    int const openFlags = RTLD_LAZY | RTLD_GLOBAL;
    return dlopen(libName, openFlags);
#endif
}

// Looks up `symbolName` in the library identified by `handle`.
// Returns nullptr when `handle` is null; otherwise returns the address
// reported by the platform loader, which is null when lookup fails.
void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
{
    if (!handle)
    {
        return nullptr;
    }

#ifdef _WIN32
    // GetProcAddress yields a function pointer (FARPROC); converting it to
    // void* needs an explicit reinterpret_cast in standard C++.
    return reinterpret_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), symbolName));
#else
    return dlsym(handle, symbolName);
#endif
}

// Returns the stored ncclCommWindowRegister function pointer.
// The pointer may be null: loadNCCLLibrary only logs a warning when this
// symbol cannot be resolved.
NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
{
return mNCCLCommWindowRegister;
}

// Returns the stored ncclMemAlloc function pointer (null if it was never resolved).
NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
{
return mNCCLMemAlloc;
}

// Returns the mIsLoaded flag, which loadNCCLLibrary sets once ncclMemAlloc
// has been resolved.
bool NCCLHelper::isLoaded() const
{
return mIsLoaded;
}

// Definition of the static flag read by UserBufferAllocator::Instance() to
// pick the allocator implementation; starts out false.
bool UserBufferAllocator::use_nccl_symmetric = false;

}; // namespace tensorrt_llm::runtime::ub
78 changes: 70 additions & 8 deletions cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@
* limitations under the License.
*/
#pragma once
#include "nccl.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <memory>
#if ENABLE_MULTI_DEVICE
#include "userbuffers.h"
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#endif

namespace tensorrt_llm::runtime::ub
Expand All @@ -28,11 +35,13 @@ struct UBBuffer
void* addr;
int handle;
size_t size;
ncclWindow_t window;

UBBuffer(void* a = nullptr, int h = -1, size_t s = 0)
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
: addr(a)
, handle(h)
, size(s)
, window(w)
{
}

Expand All @@ -49,21 +58,74 @@ class UserBufferAllocator

UserBufferAllocator() = default;

void initialize(tensorrt_llm::runtime::WorldConfig const& world_config);
bool is_initialized();
virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig);
bool isInitialized();
UBBuffer allocate(size_t bytes);
void deallocate(void* addr);
UBBuffer get(int idx);
communicator* comm();
virtual UBBuffer registerUBBuffer(size_t bytes);

static bool use_nccl_symmetric;

private:
UBBuffer register_ub_buffer(size_t bytes);
communicator* mUbComm;

communicator* ub_comm_;
std::vector<UBBuffer> buffers_;
bool is_initialized_;
tensorrt_llm::runtime::WorldConfig world_config_;
protected:
std::vector<UBBuffer> mBuffers;
bool mIsInitialized;
tensorrt_llm::runtime::WorldConfig mWorldConfig;
};

// Loads the NCCL shared library at runtime and exposes the function pointers
// for ncclCommWindowRegister and ncclMemAlloc resolved from it.
class NCCLHelper
{
public:
    NCCLHelper();
    ~NCCLHelper();

    // The helper owns a raw library handle that the destructor releases, so a
    // copy would release the same handle a second time. Copying is disabled.
    NCCLHelper(NCCLHelper const&) = delete;
    NCCLHelper& operator=(NCCLHelper const&) = delete;

    // Dynamic loading function type definition
    using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
    using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);

    // Get function pointer for ncclCommWindowRegister (may be null)
    ncclCommWindowRegisterFunc getNCCLCommWindowRegister();

    // Get function pointer for ncclMemAlloc (may be null)
    ncclMemAllocFunc getNCCLMemAlloc();

    // Check if NCCL library is successfully loaded
    bool isLoaded() const;

private:
    // Opens the library and resolves the symbols; called from the constructor.
    void loadNCCLLibrary();
    // Thin wrappers over the platform's dynamic-loading API.
    void* loadLibraryHandle(char const* libName);
    void* getSymbolAddress(void* handle, char const* symbolName);

#ifdef _WIN32
    HMODULE mLibraryHandle;
#else
    void* mLibraryHandle;
#endif

    ncclCommWindowRegisterFunc mNCCLCommWindowRegister;
    ncclMemAllocFunc mNCCLMemAlloc;
    bool mIsLoaded;
};

// UserBufferAllocator variant that overrides initialization and buffer
// registration with NCCL-based implementations.
class NCCLUserBufferAllocator : public UserBufferAllocator
{
public:
    // Parameter name matches the base-class declaration and the definition.
    void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) override;
    UBBuffer registerUBBuffer(size_t bytes) override;

    // Get shared NCCLHelper instance
    static NCCLHelper& getNCCLHelper();

private:
    // Communicator obtained in initialize().
    std::shared_ptr<ncclComm_t> mComm;
    // Lazily created helper shared by all instances; see getNCCLHelper().
    static std::unique_ptr<NCCLHelper> mNCCLHelper;
};

#else
using communicator = void;
#endif
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void ub_initialize(int tp_size)

// Reports whether the allocator singleton has been initialized.
bool ub_is_initialized()
{
    return UserBufferAllocator::Instance().isInitialized();
}

UBBuffer ub_allocate(size_t bytes)
Expand Down
Loading