diff --git a/.gitmodules b/.gitmodules
index 31970ad4054..00ff73d1364 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -23,3 +23,6 @@
 [submodule "3rdparty/nanobind"]
	path = 3rdparty/nanobind
	url = https://github.com/wjakob/nanobind
+[submodule "3rdparty/cppzmq"]
+	path = 3rdparty/cppzmq
+	url = https://github.com/zeromq/cppzmq.git
diff --git a/3rdparty/cppzmq b/3rdparty/cppzmq
new file mode 160000
index 00000000000..c94c20743ed
--- /dev/null
+++ b/3rdparty/cppzmq
@@ -0,0 +1 @@
+Subproject commit c94c20743ed7d4aa37835a5c46567ab0790d4acc
diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index 599a89cef03..48ac605a3fd 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -159,7 +159,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     {
         std::lock_guard lock(mDllMutex);
         mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
-        TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly.");
+        TLLM_CHECK_WITH_INFO(
+            mWrapperLibHandle != nullptr, "UCX wrapper library was not opened correctly. Error: %s", dlerror());
         auto load_sym = [](void* handle, char const* name)
         {
             void* ret = dllGetSym(handle, name);
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/CMakeLists.txt b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/CMakeLists.txt
index cd2095dd8dd..b8ea41b1419 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/CMakeLists.txt
+++ b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/CMakeLists.txt
@@ -4,6 +4,13 @@ if(ENABLE_UCX)
   find_package(ucx REQUIRED)
   find_package(ucxx REQUIRED)
 
+  include_directories(${3RDPARTY_DIR}/cppzmq)
+
+  # Find and link ZMQ
+  find_package(PkgConfig REQUIRED)
+  pkg_check_modules(ZMQ REQUIRED libzmq)
+
+  # Add the UCX wrapper target
   add_library(${UCX_WRAPPER_TARGET} SHARED connection.cpp ucxCacheCommunicator.cpp)
   set_target_properties(
@@ -20,4 +27,8 @@ if(ENABLE_UCX)
     PRIVATE $)
   target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ucxx::ucxx ucx::ucs)
   target_link_libraries(${UCX_WRAPPER_TARGET} PUBLIC ${CUDA_RT_LIB})
+
+  # Add ZMQ include directories and link against libzmq
+  target_include_directories(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_INCLUDE_DIRS})
+  target_link_libraries(${UCX_WRAPPER_TARGET} PRIVATE ${ZMQ_LIBRARIES})
 endif()
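The build wiring above vendors cppzmq (a header-only C++ binding for libzmq) and links the UCX wrapper against the system libzmq found via pkg-config. For orientation, here is a minimal sketch of the ZMQ request-reply pattern the rewritten connection manager below builds on; the loopback endpoint and socket names are illustrative, not taken from this patch:

// Minimal cppzmq REQ/REP round trip (illustrative sketch; not part of the patch).
#include <zmq.hpp>
#include <iostream>
#include <string>

int main()
{
    zmq::context_t ctx;

    // REP side: bind to an ephemeral port (the same wildcard-bind idiom the
    // connection manager uses), then read back the endpoint ZMQ actually chose.
    zmq::socket_t rep(ctx, zmq::socket_type::rep);
    rep.bind("tcp://127.0.0.1:*");
    std::string endpoint = rep.get(zmq::sockopt::last_endpoint);

    // REQ side: connect to the reported endpoint and issue one request.
    zmq::socket_t req(ctx, zmq::socket_type::req);
    req.connect(endpoint);
    std::string ping = "ping";
    req.send(zmq::buffer(ping), zmq::send_flags::none);

    // REQ/REP enforces strict alternation: one request in, one reply out.
    zmq::message_t request;
    if (!rep.recv(request))
        return 1;
    std::string pong = "pong";
    rep.send(zmq::buffer(pong), zmq::send_flags::none);

    zmq::message_t reply;
    if (!req.recv(reply))
        return 1;
    std::cout << endpoint << " -> " << reply.to_string() << std::endl;
    return 0;
}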
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp
index 426dafa5bcf..88cfa4ca93e 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp
@@ -22,27 +22,56 @@
 #include
 #include
 #include
+#include <regex>
 #include
+#include <zmq.hpp>
 #include
 #include
 
 namespace tensorrt_llm::executor::kv_cache
 {
 
-static void listenerCallback(ucp_conn_request_h connRequest, void* arg)
+class UcxCmMessage
 {
-    TLLM_LOG_DEBUG("listenerCallback");
-    char ipStr[INET6_ADDRSTRLEN];
-    char portStr[INET6_ADDRSTRLEN];
-    ucp_conn_request_attr_t attr{};
-    UcxConnectionManager* connectionManager = reinterpret_cast<UcxConnectionManager*>(arg);
-
-    attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
-    ucxx::utils::ucsErrorThrow(ucp_conn_request_query(connRequest, &attr));
-    ucxx::utils::sockaddr_get_ip_port_str(&attr.client_address, ipStr, portStr, INET6_ADDRSTRLEN);
-    TLLM_LOG_DEBUG("Server received a connection request from client at address %s:%s", ipStr, portStr);
-    connectionManager->addConnection(connRequest);
-}
+public:
+    enum class MessageType
+    {
+        GET_WORKER_ADDRESS = 1,
+        SERVER_WORKER_ADDRESS = 2,
+        STOP = 3,
+    };
+
+    MessageType mType;
+    std::optional<std::string> mWorkerAddress;
+
+    UcxCmMessage(MessageType type, std::optional<std::string> workerAddress)
+        : mType(type)
+        , mWorkerAddress(std::move(workerAddress))
+    {
+    }
+
+    static size_t serializedSize(UcxCmMessage const& message)
+    {
+        namespace su = tensorrt_llm::executor::serialize_utils;
+
+        return su::serializedSize(message.mType) + su::serializedSize(message.mWorkerAddress);
+    }
+
+    static void serialize(UcxCmMessage const& message, std::ostream& os)
+    {
+        namespace su = tensorrt_llm::executor::serialize_utils;
+        su::serialize(message.mType, os);
+        su::serialize(message.mWorkerAddress, os);
+    }
+
+    static UcxCmMessage deserialize(std::istream& is)
+    {
+        namespace su = tensorrt_llm::executor::serialize_utils;
+        auto type = su::deserialize<MessageType>(is);
+        auto workerAddress = su::deserialize<std::optional<std::string>>(is);
+        return UcxCmMessage(type, workerAddress);
+    }
+};
 
 static std::string getLocalIp()
 {
@@ -100,6 +129,22 @@ static std::string getLocalIp()
     return ip;
 }
 
+std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const& endpoint)
+{
+    std::regex ipv4_regex(R"(tcp://([\d\.]+):(\d+))");
+    std::regex ipv6_regex(R"(tcp://\[([0-9a-fA-F:]+)\]:(\d+))");
+    std::smatch match;
+    if (std::regex_match(endpoint, match, ipv4_regex))
+    {
+        return std::make_pair(match[1].str(), std::stoi(match[2].str()));
+    }
+    else if (std::regex_match(endpoint, match, ipv6_regex))
+    {
+        return std::make_pair(match[1].str(), std::stoi(match[2].str()));
+    }
+    return std::nullopt;
+}
+
 UcxConnectionManager::UcxConnectionManager()
 {
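parse_zmq_endpoint recovers the (ip, port) pair from the endpoint string that zmq::sockopt::last_endpoint reports after a wildcard bind. A standalone check of the two regexes; the function body is copied from the hunk above so the sketch compiles on its own, and the endpoint values are illustrative:

// Standalone round trip for the endpoint parser above.
#include <cassert>
#include <iostream>
#include <optional>
#include <regex>
#include <string>
#include <utility>

static std::optional<std::pair<std::string, int>> parse_zmq_endpoint(std::string const& endpoint)
{
    std::regex ipv4_regex(R"(tcp://([\d\.]+):(\d+))");
    std::regex ipv6_regex(R"(tcp://\[([0-9a-fA-F:]+)\]:(\d+))");
    std::smatch match;
    if (std::regex_match(endpoint, match, ipv4_regex) || std::regex_match(endpoint, match, ipv6_regex))
    {
        return std::make_pair(match[1].str(), std::stoi(match[2].str()));
    }
    return std::nullopt;
}

int main()
{
    // ZMQ reports a resolved wildcard bind as e.g. "tcp://10.0.0.1:41357".
    auto v4 = parse_zmq_endpoint("tcp://10.0.0.1:41357");
    assert(v4 && v4->first == "10.0.0.1" && v4->second == 41357);

    // IPv6 endpoints carry brackets around the address.
    auto v6 = parse_zmq_endpoint("tcp://[::1]:41357");
    assert(v6 && v6->first == "::1" && v6->second == 41357);

    // Non-TCP transports (e.g. ipc://) are rejected.
    assert(!parse_zmq_endpoint("ipc:///tmp/feeds/0"));
    std::cout << v4->first << ":" << v4->second << "\n";
    return 0;
}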
@@ -120,22 +165,23 @@ UcxConnectionManager::UcxConnectionManager()
             std::string error = "Error creating worker and starting progress thread for rank " + std::string(e.what());
             TLLM_THROW(error);
         }
+        auto workerAddressPtr = mWorkersPool.front()->getAddress();
+        mWorkerAddress = workerAddressPtr->getString();
 
-        try
-        {
-
-            mListener = mWorkersPool.front()->createListener(0, listenerCallback, this);
-        }
-        catch (std::exception const& e)
-        {
-            std::string error = "Error creating listener for rank " + std::string(e.what());
-            TLLM_THROW(error);
-        }
-
-        // Get local IP address
+        mZmqRepSocket = zmq::socket_t(mZmqContext, zmq::socket_type::rep);
+        mZmqRepSocket.set(zmq::sockopt::sndhwm, 1000);
         std::string localIp = getLocalIp();
-        auto port = mListener->getPort();
-        SocketState socketState{port, localIp};
+        mZmqRepSocket.bind("tcp://" + localIp + ":*");
+        mZmqRepEndpoint = mZmqRepSocket.get(zmq::sockopt::last_endpoint);
+        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
+            "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s", mZmqRepEndpoint.c_str());
+        auto parse_result = parse_zmq_endpoint(mZmqRepEndpoint);
+        TLLM_CHECK_WITH_INFO(parse_result.has_value(), "Failed to parse ZMQ endpoint");
+        auto [ip, port] = parse_result.value();
+        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
+            "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d", ip.c_str(), port);
+
+        SocketState socketState{static_cast<uint16_t>(port), ip};
         std::vector<SocketState> socketStates(mpi::MpiComm::session().getSize());
 
         if (mpi::MpiComm::session().getSize() > 1)
@@ -179,6 +225,47 @@ UcxConnectionManager::UcxConnectionManager()
         }
         mCommState = CommState(socketStates, mpi::MpiComm::session().getRank());
         TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " ***** UCX mCommState: %s", mCommState.toString().c_str());
+
+        mZmqRepThread = std::thread(
+            [this]()
+            {
+                while (true)
+                {
+                    zmq::message_t message;
+                    auto ret = mZmqRepSocket.recv(message);
+                    TLLM_CHECK_WITH_INFO(ret, "mZmqRepSocket.recv failed");
+                    std::string recvMessage(static_cast<char const*>(message.data()), message.size());
+                    std::istringstream is(recvMessage);
+                    UcxCmMessage ucxCmMessage = UcxCmMessage::deserialize(is);
+
+                    if (ucxCmMessage.mType == UcxCmMessage::MessageType::GET_WORKER_ADDRESS)
+                    {
+                        // Reply with our own worker address, then add a connection to the requester.
+                        TLLM_CHECK_WITH_INFO(ucxCmMessage.mWorkerAddress.has_value(), "workerAddress is null");
+                        std::string workerAddress = ucxCmMessage.mWorkerAddress.value();
+                        std::string selfWorkerAddress = mWorkerAddress;
+                        UcxCmMessage serverMessage(UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS, selfWorkerAddress);
+                        std::ostringstream oStream;
+                        UcxCmMessage::serialize(serverMessage, oStream);
+                        std::string serverMessageStr = oStream.str();
+                        mZmqRepSocket.send(zmq::buffer(serverMessageStr), zmq::send_flags::none);
+                        addConnection(workerAddress);
+                    }
+                    else if (ucxCmMessage.mType == UcxCmMessage::MessageType::STOP)
+                    {
+                        UcxCmMessage stopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
+                        std::ostringstream oStream;
+                        UcxCmMessage::serialize(stopMessage, oStream);
+                        std::string stopMessageStr = oStream.str();
+                        mZmqRepSocket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
+                        break;
+                    }
+                    else
+                    {
+                        TLLM_THROW("Zmq recv unknown message: %s", recvMessage.c_str());
+                    }
+                }
+            });
     }
     catch (std::exception const& e)
     {
@@ -195,14 +282,38 @@ UcxConnectionManager::~UcxConnectionManager()
     {
         worker->stopProgressThread();
     }
+    if (mZmqRepThread.joinable())
+    {
+        zmq::socket_t socket(mZmqContext, zmq::socket_type::req);
+        socket.connect(mZmqRepEndpoint);
+        UcxCmMessage stopMessage(UcxCmMessage::MessageType::STOP, std::nullopt);
+        std::ostringstream oStream;
+        UcxCmMessage::serialize(stopMessage, oStream);
+        std::string stopMessageStr = oStream.str();
+        socket.send(zmq::buffer(stopMessageStr), zmq::send_flags::none);
+        zmq::message_t reply;
+        auto ret = socket.recv(reply);
+        TLLM_CHECK_WITH_INFO(ret, "zmq socket.recv failed");
+        std::string replyStr(static_cast<char const*>(reply.data()), reply.size());
+        std::istringstream is(replyStr);
+        UcxCmMessage serverMessage = UcxCmMessage::deserialize(is);
+        TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::STOP, "serverMessage.mType is not STOP");
+        socket.close();
+        mZmqRepThread.join();
+    }
+
+    mZmqRepSocket.close();
+
+    mZmqContext.close();
 
     TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "END UcxConnectionManager::~UcxConnectionManager");
 }
 
-void UcxConnectionManager::addConnection(ucp_conn_request_h connRequest)
+void UcxConnectionManager::addConnection(std::string const& workerAddress)
 {
     try
     {
-        std::shared_ptr<ucxx::Endpoint> newEp = mListener->createEndpointFromConnRequest(connRequest, true);
+        auto workerAddressPtr = ucxx::createAddressFromString(workerAddress);
+        auto newEp = mWorkersPool.front()->createEndpointFromWorkerAddress(workerAddressPtr, true);
         UcxConnection::ConnectionIdType connectionId = getNewConnectionId(newEp);
 
         std::scoped_lock lock(mConnectionFuturesMutex);
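The REP loop and the destructor's STOP round trip above define a small three-message protocol: a peer sends GET_WORKER_ADDRESS carrying its UCX worker address, receives SERVER_WORKER_ADDRESS with the local one, and the destructor unblocks its own loop with STOP. The patch serializes messages via tensorrt_llm::executor::serialize_utils, whose internals are outside this diff; the self-contained sketch below substitutes plain length-prefixed stream I/O to show the shape of a round trip:

// Illustrative UcxCmMessage round trip with stand-in serialization.
#include <cassert>
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>

enum class MessageType : int32_t { GET_WORKER_ADDRESS = 1, SERVER_WORKER_ADDRESS = 2, STOP = 3 };

struct Message
{
    MessageType mType;
    std::optional<std::string> mWorkerAddress;

    void serialize(std::ostream& os) const
    {
        auto type = static_cast<int32_t>(mType);
        bool hasValue = mWorkerAddress.has_value();
        uint64_t size = hasValue ? mWorkerAddress->size() : 0;
        os.write(reinterpret_cast<char const*>(&type), sizeof(type));
        os.write(reinterpret_cast<char const*>(&hasValue), sizeof(hasValue));
        os.write(reinterpret_cast<char const*>(&size), sizeof(size));
        if (hasValue)
            os.write(mWorkerAddress->data(), static_cast<std::streamsize>(size));
    }

    static Message deserialize(std::istream& is)
    {
        int32_t type = 0;
        bool hasValue = false;
        uint64_t size = 0;
        is.read(reinterpret_cast<char*>(&type), sizeof(type));
        is.read(reinterpret_cast<char*>(&hasValue), sizeof(hasValue));
        is.read(reinterpret_cast<char*>(&size), sizeof(size));
        std::optional<std::string> address;
        if (hasValue)
        {
            std::string s(size, '\0');
            is.read(s.data(), static_cast<std::streamsize>(size));
            address = std::move(s);
        }
        return Message{static_cast<MessageType>(type), std::move(address)};
    }
};

int main()
{
    // Round trip: what one REQ/REP exchange carries over the wire.
    Message request{MessageType::GET_WORKER_ADDRESS, std::string("ucx-worker-address-blob")};
    std::ostringstream os;
    request.serialize(os);

    std::istringstream is(os.str());
    Message decoded = Message::deserialize(is);
    assert(decoded.mType == MessageType::GET_WORKER_ADDRESS);
    assert(decoded.mWorkerAddress == "ucx-worker-address-blob");
    return 0;
}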
@@ -225,6 +336,23 @@
     }
 }
 
+std::string build_zmq_endpoint(std::string const& ip, uint16_t port)
+{
+    std::ostringstream oss;
+
+    std::regex ipv6_regex(R"([0-9a-fA-F]*:[0-9a-fA-F]*:[0-9a-fA-F]*.*)");
+    if (std::regex_match(ip, ipv6_regex) && ip.find(':') != std::string::npos)
+    {
+        oss << "tcp://[" << ip << "]:" << port;
+    }
+    else
+    {
+        oss << "tcp://" << ip << ":" << port;
+    }
+
+    return oss.str();
+}
+
 UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string const& ip, uint16_t port)
 {
     static std::mutex sAddConnectionIPMutex;
@@ -237,7 +365,24 @@ UcxConnection::ConnectionIdType UcxConnectionManager::addConnection(std::string
         // This lock ensures that only one thread can create an endpoint from hostname and establish a UCX
         // connection at a time, guaranteeing that only one listener will send a connectionId to the requester
         // at a time.
-        std::shared_ptr<ucxx::Endpoint> newEp = mWorkersPool.front()->createEndpointFromHostname(ip, port, true);
+        auto reqSocket = zmq::socket_t(mZmqContext, zmq::socket_type::req);
+        reqSocket.connect(build_zmq_endpoint(ip, port));
+        UcxCmMessage getWorkerAddressMessage(UcxCmMessage::MessageType::GET_WORKER_ADDRESS, mWorkerAddress);
+        std::ostringstream oStream;
+        UcxCmMessage::serialize(getWorkerAddressMessage, oStream);
+        std::string getWorkerAddressMessageStr = oStream.str();
+        reqSocket.send(zmq::buffer(getWorkerAddressMessageStr), zmq::send_flags::none);
+        zmq::message_t reply;
+        auto ret = reqSocket.recv(reply);
+        TLLM_CHECK_WITH_INFO(ret, "zmq socket.recv failed");
+        std::string replyStr(static_cast<char const*>(reply.data()), reply.size());
+        std::istringstream is(replyStr);
+        UcxCmMessage serverMessage = UcxCmMessage::deserialize(is);
+        TLLM_CHECK_WITH_INFO(serverMessage.mType == UcxCmMessage::MessageType::SERVER_WORKER_ADDRESS,
+            "serverMessage.mType is not SERVER_WORKER_ADDRESS");
+        std::string serverWorkerAddress = serverMessage.mWorkerAddress.value();
+        auto serverWorkerAddressPtr = ucxx::createAddressFromString(serverWorkerAddress);
+        auto newEp = mWorkersPool.front()->createEndpointFromWorkerAddress(serverWorkerAddressPtr, true);
         connectionId = getNewConnectionId(newEp);
         connection = std::make_shared<UcxConnection>(connectionId, newEp, this, true);
     }
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h
index 9f60b3dfcdd..642f1f2c6e2 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h
+++ b/cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include <zmq.hpp>
 
 namespace tensorrt_llm::executor::kv_cache
 {
@@ -45,16 +46,20 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared_from_this<UcxConnectionManager>
 private:
     std::shared_ptr<ucxx::Context> mUcxCtx;
     std::vector<std::shared_ptr<ucxx::Worker>> mWorkersPool;
+    std::string mWorkerAddress;
     std::map<UcxConnection::ConnectionIdType, std::shared_ptr<UcxConnection>> mConnections;
     std::map<UcxConnection::ConnectionIdType, std::future<void>> mConnectionFutures;
     std::mutex mConnectionsMutex;
     std::mutex mConnectionFuturesMutex;
     std::unordered_map<std::string, UcxConnection::ConnectionIdType> mAddressToConnectionId;
     std::mutex mAddressToConnectionIdMutex;
-    std::shared_ptr<ucxx::Listener> mListener;
     CommState mCommState;
     int mDevice;
     std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
+    zmq::context_t mZmqContext;
+    zmq::socket_t mZmqRepSocket;
+    std::string mZmqRepEndpoint;
+    std::thread mZmqRepThread;
 
     UcxConnection::ConnectionIdType getNewConnectionId(std::shared_ptr<ucxx::Endpoint> const& newEp);
     UcxConnection::ConnectionIdType addConnection(std::string const& ip, uint16_t port);
@@ -69,7 +74,7 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared_from_this<UcxConnectionManager>
         return std::make_unique<UcxConnectionManager>();
     }
 
-    void addConnection(ucp_conn_request_h connRequest);
+    void addConnection(std::string const& workerAddress);
     Connection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
     std::vector<Connection const*> getConnections(CommState const& state) override;
     [[nodiscard]] CommState const& getCommState() const override;
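The header changes complete the switch: the UCX listener is gone, and connections are built purely from exchanged worker-address strings. As a hedged sketch of why a string swap suffices, two ucxx workers in one process exchange serialized addresses and build endpoints from them. The createWorker, getAddress, getString, createAddressFromString, and createEndpointFromWorkerAddress calls mirror the ones in this diff; createContext and defaultFeatureFlags are standard ucxx API assumed here:

// Illustrative: worker-address exchange without a UCX listener.
#include <ucxx/api.h>
#include <string>

int main()
{
    auto context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags);
    auto workerA = context->createWorker();
    auto workerB = context->createWorker();

    // Each side publishes its worker address as an opaque string blob.
    std::string addrA = workerA->getAddress()->getString();
    std::string addrB = workerB->getAddress()->getString();

    // After the blobs travel over ZMQ (here: plain variables), endpoints are
    // created directly from the peer's worker address; no listener is needed.
    auto epAB = workerA->createEndpointFromWorkerAddress(ucxx::createAddressFromString(addrB), true);
    auto epBA = workerB->createEndpointFromWorkerAddress(ucxx::createAddressFromString(addrA), true);
    return 0;
}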
diff --git a/docker/common/install_base.sh b/docker/common/install_base.sh
index bf6e11420a6..cd5e45ee00b 100644
--- a/docker/common/install_base.sh
+++ b/docker/common/install_base.sh
@@ -61,7 +61,8 @@ init_ubuntu() {
         python3-pip \
         python-is-python3 \
         wget \
-        pigz
+        pigz \
+        libzmq3-dev
     if ! command -v mpirun &> /dev/null; then
         DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
     fi
@@ -129,6 +130,7 @@ install_gcctoolset_rockylinux() {
         openmpi-devel \
         pigz \
         rdma-core-devel \
+        zeromq-devel \
        -y
     echo "source scl_source enable gcc-toolset-11" >> "${ENV}"
     echo 'export PATH=/usr/lib64/openmpi/bin:$PATH' >> "${ENV}"
diff --git a/jenkins/current_image_tags.properties b/jenkins/current_image_tags.properties
index 24078a941fb..dee2ee7218f 100644
--- a/jenkins/current_image_tags.properties
+++ b/jenkins/current_image_tags.properties
@@ -11,7 +11,7 @@
 #
 # NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
 #     images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
-LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
-LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507251001-5678
-LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202507251001-5678
-LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202507251001-5678
+LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
+LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.06-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202508051130-6090
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202508051130-6090
+LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.1-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202508051130-6090