3 changes: 3 additions & 0 deletions .gitmodules
@@ -23,3 +23,6 @@
[submodule "3rdparty/nanobind"]
path = 3rdparty/nanobind
url = https://github.com/wjakob/nanobind
[submodule "3rdparty/cppzmq"]
path = 3rdparty/cppzmq
url = https://github.com/zeromq/cppzmq.git
1 change: 1 addition & 0 deletions 3rdparty/cppzmq
Submodule cppzmq added at c94c20
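
cppzmq is a header-only C++ wrapper over libzmq, so the build needs both this submodule (for zmq.hpp) and a system libzmq, which the CMake and install-script changes below provide. For orientation, a minimal REQ/REP round trip using the same cppzmq calls the new connection-manager code relies on might look like this (illustrative sketch, not part of the change):

#include <string>
#include <zmq.hpp>

int main()
{
    zmq::context_t ctx;

    // REP side: bind to an ephemeral TCP port and read back the chosen endpoint,
    // just as the connection manager does with zmq::sockopt::last_endpoint.
    zmq::socket_t rep(ctx, zmq::socket_type::rep);
    rep.bind("tcp://127.0.0.1:*");
    std::string endpoint = rep.get(zmq::sockopt::last_endpoint);

    // REQ side: connect and send one request.
    zmq::socket_t req(ctx, zmq::socket_type::req);
    req.connect(endpoint);
    std::string ping = "ping";
    req.send(zmq::buffer(ping), zmq::send_flags::none);

    // REP side receives and replies; REQ side reads the reply.
    zmq::message_t request;
    auto got = rep.recv(request);          // returns std::optional<size_t>
    std::string pong = "pong";
    rep.send(zmq::buffer(pong), zmq::send_flags::none);
    zmq::message_t reply;
    got = req.recv(reply);
    return got ? 0 : 1;
}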
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -159,7 +159,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
{
std::lock_guard<std::mutex> lock(mDllMutex);
mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly.");
TLLM_CHECK_WITH_INFO(
mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly. error : %s", dlerror());
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
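
The change in this hunk only appends dlerror() to the check message, so a failed dllOpen of the UCX wrapper now reports the loader's reason. For reference, the underlying POSIX pattern looks roughly like this (sketch; dllOpen/dllGetSym are assumed here to wrap dlopen/dlsym on Linux):

#include <dlfcn.h>
#include <stdexcept>
#include <string>

void* openSharedLibraryOrThrow(char const* libName)
{
    void* handle = dlopen(libName, RTLD_LAZY | RTLD_LOCAL);
    if (handle == nullptr)
    {
        // dlerror() describes the most recent dl* failure, e.g.
        // "cannot open shared object file: No such file or directory".
        char const* err = dlerror();
        throw std::runtime_error(
            std::string("failed to load ") + libName + ": " + (err != nullptr ? err : "unknown error"));
    }
    return handle;
}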
@@ -4,6 +4,13 @@ if(ENABLE_UCX)
find_package(ucx REQUIRED)
find_package(ucxx REQUIRED)

include_directories(${3RDPARTY_DIR}/cppzmq)

# Find and link ZMQ
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# Add the UCX wrapper target

add_library(${UCX_WRAPPER_TARGET} SHARED connection.cpp
ucxCacheCommunicator.cpp)
set_target_properties(
@@ -20,4 +27,8 @@ if(ENABLE_UCX)
PRIVATE $<LINK_LIBRARY:WHOLE_ARCHIVE,ucxx::ucxx>)
target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ucxx::ucxx ucx::ucs)
target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${CUDA_RT_LIB})

# Add include directories
target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})
target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_LIBRARIES})
endif()
@@ -22,27 +22,56 @@
#include <exception>
#include <iostream>
#include <mutex>
#include <regex>
#include <sys/socket.h>
#include <ucxx/address.h>
#include <ucxx/typedefs.h>
#include <unistd.h>

namespace tensorrt_llm::executor::kv_cache
{

static void listenerCallback(ucp_conn_request_h connRequest, void* arg)
class UcxCmMessage
{
TLLM_LOG_DEBUG("listenerCallback");
char ipStr[INET6_ADDRSTRLEN];
char portStr[INET6_ADDRSTRLEN];
ucp_conn_request_attr_t attr{};
UcxConnectionManager* connectionManager = reinterpret_cast<UcxConnectionManager*>(arg);

attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
ucxx::utils::ucsErrorThrow(ucp_conn_request_query(connRequest, &attr));
ucxx::utils::sockaddr_get_ip_port_str(&attr.client_address, ipStr, portStr, INET6_ADDRSTRLEN);
TLLM_LOG_DEBUG("Server received a connection request from client at address %s:%s", ipStr, portStr);
connectionManager->addConnection(connRequest);
}
public:
enum class MessageType
{
GET_WORKER_ADDRESS = 1,
SERVER_WORKER_ADDRESS = 2,
STOP = 3,
};

MessageType mType;
std::optional<std::string> mWorkerAddress;

UcxCmMessage(MessageType type, std::optional<std::string> workerAddress)
: mType(type)
, mWorkerAddress(std::move(workerAddress))
{
}

static size_t serializedSize(UcxCmMessage const& message)
{
namespace su = tensorrt_llm::executor::serialize_utils;

return su::serializedSize(message.mType) + su::serializedSize(message.mWorkerAddress);
}

static void serialize(UcxCmMessage const& message, std::ostream& os)
{
namespace su = tensorrt_llm::executor::serialize_utils;
su::serialize(message.mType, os);
su::serialize(message.mWorkerAddress, os);
}

static UcxCmMessage deserialize(std::istream& is)
{
namespace su = tensorrt_llm::executor::serialize_utils;
auto type = su::deserialize<MessageType>(is);
auto workerAddress = su::deserialize<std::optional<std::string>>(is);
return UcxCmMessage(type, workerAddress);
}
};
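
UcxCmMessage is the payload exchanged over the ZMQ REQ/REP sockets: a client sends GET_WORKER_ADDRESS carrying its own UCX worker address, the server replies with SERVER_WORKER_ADDRESS, and STOP shuts the reply loop down. A round trip through the stream helpers might look like this (sketch using the class's own static methods; the encode/decode helper names are illustrative):

#include <sstream>
#include <string>

// Encode a request carrying the caller's serialized UCX worker address.
std::string encodeWorkerAddressRequest(std::string const& localWorkerAddress)
{
    UcxCmMessage request(UcxCmMessage::MessageType::GET_WORKER_ADDRESS, localWorkerAddress);
    std::ostringstream os;
    UcxCmMessage::serialize(request, os);
    return os.str();                     // the bytes handed to zmq::buffer(...)
}

// Decode whatever arrived on the socket; mType tells the REP loop how to react.
UcxCmMessage decodeMessage(std::string const& wire)
{
    std::istringstream is(wire);
    return UcxCmMessage::deserialize(is);
}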

static std::string getLocalIp()
{
@@ -100,6 +129,22 @@ static std::string getLocalIp()
return ip;
}

std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const& endpoint)
{
std::regex ipv4_regex(R"(tcp://([\d\.]+):(\d+))");
std::regex ipv6_regex(R"(tcp://\[([0-9a-fA-F:]+)\]:(\d+))");
std::smatch match;
if (std::regex_match(endpoint, match, ipv4_regex))
{
return std::make_pair(match[1].str(), std::stoi(match[2].str()));
}
else if (std::regex_match(endpoint, match, ipv6_regex))
{
return std::make_pair(match[1].str(), std::stoi(match[2].str()));
}
return std::nullopt;
}
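
parse_zmq_endpoint recovers the ip/port pair from the endpoint string that ZMQ reports after binding to a wildcard port; non-TCP endpoints yield std::nullopt. For instance (illustrative values):

#include <cassert>

void parseEndpointExamples()
{
    auto v4 = parse_zmq_endpoint("tcp://10.31.4.12:41233");
    assert(v4 && v4->first == "10.31.4.12" && v4->second == 41233);

    auto v6 = parse_zmq_endpoint("tcp://[fe80::1a2b]:41233");
    assert(v6 && v6->first == "fe80::1a2b" && v6->second == 41233);

    assert(!parse_zmq_endpoint("ipc:///tmp/kv_cache"));   // not a TCP endpoint
}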

UcxConnectionManager::UcxConnectionManager()

{
@@ -120,22 +165,23 @@ UcxConnectionManager::UcxConnectionManager()
std::string error = "Error creating worker and starting progress thread for rank " + std::string(e.what());
TLLM_THROW(error);
}
auto workerAddressPtr = mWorkersPool.front()->getAddress();
mWorkerAddress = workerAddressPtr->getString();

try
{

mListener = mWorkersPool.front()->createListener(0, listenerCallback, this);
}
catch (std::exception const& e)
{
std::string error = "Error creating listener for rank " + std::string(e.what());
TLLM_THROW(error);
}

// Get local IP address
mZmqRepSocket = zmq::socket_t(mZmqContext, zmq::socket_type::rep);
mZmqRepSocket.set(zmq::sockopt::sndhwm, 1000);
std::string localIp = getLocalIp();
auto port = mListener->getPort();
SocketState socketState{port, localIp};
mZmqRepSocket.bind("tcp://" + localIp + ":*");
mZmqRepEndpoint = mZmqRepSocket.get(zmq::sockopt::last_endpoint);
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s",
mZmqRepEndpoint.c_str());
auto parse_result = parse_zmq_endpoint(mZmqRepEndpoint);
TLLM_CHECK_WITH_INFO(parse_result.has_value(), "Failed to parse ZMQ endpoint");
auto [ip, port] = parse_result.value();
TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d",
ip.c_str(), port);

SocketState socketState{static_cast<uint16_t>(port), ip};
std::vector<executor::kv_cache::SocketState> socketStates(mpi::MpiComm::session().getSize());

if (mpi::MpiComm::session().getSize() > 1)
@@ -179,6 +225,47 @@ UcxConnectionManager::UcxConnectionManager()
}
mCommState = CommState(socketStates, mpi::MpiComm::session().getRank());
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " ***** UCX mCommState: %s", mCommState.toString().c_str());

mZmqRepThread = std::thread(
[this]()
{
while (true)
{
zmq::message_t message;
auto ret = mZmqRepSocket.recv(message);
TLLM_CHECK_WITH_INFO(ret, "mZmqRepSocket.recv failed");
std::string recvMessage(static_cast<char*>(message.data()), message.size());
std::istringstream is(recvMessage);
UcxCmMessage ucxCmessage = UcxCmMessage::deserialize(is);

if (ucxCmessage.mType == UcxCmMessage::MessageType::GET_WORKER_ADDRESS)
{
// add Connection
TLLM_CHECK_WITH_INFO(ucxCmessage.mWorkerAddress.has_value(), "workerAddress is null");
std::string workerAddress = ucxCmessage.mWorkerAddress.value();
std::string selfWorkerAddress = mWorkerAddress;
UcxCmMessage serverMessage(UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS, selfWorkerAddress);
std::ostringstream oStream;
UcxCmMessage::serialize(serverMessage, oStream);
std::string serverMessageStr = oStream.str();
mZmqRepSocket.send(zmq::buffer(serverMessageStr), zmq::send_flags::none);
addConnection(workerAddress);
}
else if (ucxCmessage.mType == UcxCmMessage::MessageType::STOP)
{
UcxCmMessage stopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
std::ostringstream oStream;
UcxCmMessage::serialize(stopMessage, oStream);
std::string stopMessageStr = oStream.str();
mZmqRepSocket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
break;
}
else
{
TLLM_THROW("Zmq recv unknown message: %s", recvMessage.c_str());
}
}
});
}
catch (std::exception const& e)
{
@@ -195,14 +282,38 @@ UcxConnectionManager::~UcxConnectionManager()
{
worker->stopProgressThread();
}
if (mZmqRepThread.joinable())
{
zmq::socket_t socket(mZmqContext, zmq::socket_type::req);
socket.connect(mZmqRepEndpoint);
UcxCmMessage stopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
std::ostringstream oStream;
UcxCmMessage::serialize(stopMessage, oStream);
std::string stopMessageStr = oStream.str();
socket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
zmq::message_t reply;
auto ret = socket.recv(reply);
TLLM_CHECK_WITH_INFO(ret, "zmq socket.recv failed");
std::string replyStr(static_cast<char*>(reply.data()), reply.size());
std::istringstream is(replyStr);
UcxCmMessage serverMessage = UcxCmMessage::deserialize(is);
TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::STOP, "serverMessage.mType is not STOP");
socket.close();
mZmqRepThread.join();
}

mZmqRepSocket.close();

mZmqContext.close();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "END UcxConnectionManager::~UcxConnectionManager");
}

void UcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
void UcxConnectionManager::addConnection(std::string const& workerAddress)
{
try
{
std::shared_ptr<ucxx::Endpoint> newEp = mListener->createEndpointFromConnRequest(connRequest, true);
auto workerAddressPtr = ucxx::createAddressFromString(workerAddress);
auto newEp = mWorkersPool.front()->createEndpointFromWorkerAddress(workerAddressPtr, true);

UcxConnection::ConnectionIdType connectionId = getNewConnectionId(newEp);
std::scoped_lock lock(mConnectionFuturesMutex);
@@ -225,6 +336,23 @@ void UcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
}
}

std::string build_zmq_endpoint(std::string const& ip, uint16_t port)
{
std::ostringstream oss;

std::regex ipv6_regex(R"([0-9a-fA-F]*:[0-9a-fA-F]*:[0-9a-fA-F]*.*)");
if (std::regex_match(ip, ipv6_regex) && ip.find(':') != std::string::npos)
{
oss << "tcp://[" << ip << "]:" << port;
}
else
{
oss << "tcp://" << ip << ":" << port;
}

return oss.str();
}
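
build_zmq_endpoint is the counterpart of parse_zmq_endpoint above: IPv6 literals have to be bracketed inside a tcp:// endpoint so the port separator is unambiguous, while IPv4 addresses are used as-is. For instance (illustrative values):

#include <cassert>

void buildEndpointExamples()
{
    // IPv4: no brackets.
    assert(build_zmq_endpoint("10.31.4.12", 41233) == "tcp://10.31.4.12:41233");
    // IPv6: bracketed.
    assert(build_zmq_endpoint("fe80::1a2b", 41233) == "tcp://[fe80::1a2b]:41233");
}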

UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string const& ip, uint16_t port)
{
static std::mutex sAddConnectionIPMutex;
@@ -237,7 +365,24 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
// This lock ensures that only one thread can create an endpoint from hostname and establish a UCX
// connection at a time, guaranteeing that the only one listener will send connectionId to requester in the
// same time.
std::shared_ptr<ucxx::Endpoint> newEp = mWorkersPool.front()->createEndpointFromHostname(ip, port, true);
auto reqSocket = zmq::socket_t(mZmqContext, zmq::socket_type::req);
reqSocket.connect(build_zmq_endpoint(ip, port));
UcxCmMessage getWorkerAddressMessage(UcxCmMessage::MessageType::GET_WORKER_ADDRESS, mWorkerAddress);
std::ostringstream oStream;
UcxCmMessage::serialize(getWorkerAddressMessage, oStream);
std::string getWorkerAddressMessageStr = oStream.str();
reqSocket.send(zmq::buffer(getWorkerAddressMessageStr), zmq::send_flags::none);
zmq::message_t reply;
auto ret = reqSocket.recv(reply);
TLLM_CHECK_WITH_INFO(ret, "zmq socket.recv failed");
std::string replyStr(static_cast<char*>(reply.data()), reply.size());
std::istringstream is(replyStr);
UcxCmMessage serverMessage = UcxCmMessage::deserialize(is);
TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS,
"serverMessage.mType is not SERVER_WORKER_ADDRESS");
std::string serverWorkerAddress = serverMessage.mWorkerAddress.value();
auto serverWorkerAddressPtr = ucxx::createAddressFromString(serverWorkerAddress);
auto newEp = mWorkersPool.front()->createEndpointFromWorkerAddress(serverWorkerAddressPtr, true);
connectionId = getNewConnectionId(newEp);
connection = std::make_shared<UcxConnection>(connectionId, newEp, this, true);
}
@@ -36,6 +36,7 @@
#include <memory>
#include <string>
#include <vector>
#include <zmq.hpp>

namespace tensorrt_llm::executor::kv_cache
{
@@ -45,16 +46,20 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
private:
std::shared_ptr<ucxx::Context> mUcxCtx;
std::vector<std::shared_ptr<ucxx::Worker>> mWorkersPool;
std::string mWorkerAddress;
std::map<UcxConnection::ConnectionIdType, std::shared_ptr<UcxConnection>> mConnections;
std::map<UcxConnection::ConnectionIdType, std::future<void>> mConnectionFutures;
std::mutex mConnectionsMutex;
std::mutex mConnectionFuturesMutex;
std::unordered_map<std::string, uint64_t> mAddressToConnectionId;
std::mutex mAddressToConnectionIdMutex;
std::shared_ptr<ucxx::Listener> mListener;
CommState mCommState;
int mDevice;
std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
zmq::context_t mZmqContext;
zmq::socket_t mZmqRepSocket;
std::string mZmqRepEndpoint;
std::thread mZmqRepThread;

UcxConnection::ConnectionIdType getNewConnectionId(std::shared_ptr<ucxx::Endpoint> const& newEp);
UcxConnection::ConnectionIdType addConnection(std::string const& ip, uint16_t port);
@@ -69,7 +74,7 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
return std::make_unique<UcxConnectionManager>();
}

void addConnection(ucp_conn_request_h connRequest);
void addConnection(std::string const& workerAddress);
Connection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
std::vector<Connection const*> getConnections(CommState const& state) override;
[[nodiscard]] CommState const& getCommState() const override;
4 changes: 3 additions & 1 deletion docker/common/install_base.sh
@@ -61,7 +61,8 @@ init_ubuntu() {
python3-pip \
python-is-python3 \
wget \
pigz
pigz \
libzmq3-dev
if ! command -v mpirun &> /dev/null; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
fi
@@ -129,6 +130,7 @@ install_gcctoolset_rockylinux() {
openmpi-devel \
pigz \
rdma-core-devel \
zeromq-devel \
-y
echo "source scl_source enable gcc-toolset-11" >> "${ENV}"
echo 'export PATH=/usr/lib64/openmpi/bin:$PATH' >> "${ENV}"
8 changes: 4 additions & 4 deletions jenkins/current_image_tags.properties
@@ -11,7 +11,7 @@
#
# NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
# images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202507251001-5678
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202507251001-5678
LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202508051130-6090
LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202508051130-6090