Commits (24)

2003ccb
feat: Ray integration
tongyuantongyu Aug 18, 2025
fd171ef
Gracefully exit.
joyang-nv Aug 18, 2025
f560d01
fix single node disagg random failure
hchings Aug 19, 2025
b7d77ac
revert cpu req
hchings Aug 19, 2025
c05b72c
cleanup ray process rendering in process tree report
tongyuantongyu Aug 18, 2025
6c689d5
Use weakref for results and make ray shared queue exit gracefully.
joyang-nv Aug 19, 2025
f4c4113
Use different way to get weak ref of actor handle.
joyang-nv Aug 19, 2025
57892a1
Cache transceiver refactor
shuyixiong Aug 19, 2025
5fd3097
Remove unused code for ray.
joyang-nv Aug 19, 2025
955436f
Refine code
shuyixiong Aug 19, 2025
c6310d0
update run_cluster.sh for multinode disagg
hchings Aug 19, 2025
f2ec920
WAR ProcessGroup not pickleable.
tongyuantongyu Aug 20, 2025
5004104
fix ci failure, fix abort_request()
hchings Aug 20, 2025
9f2e709
Add heartbeat to prevent trtllm-serve timeout due to gloo. Fix pre-co…
hchings Aug 20, 2025
115b971
ccacheTransceiver nits
tongyuantongyu Aug 21, 2025
9b76654
fix build
tongyuantongyu Aug 21, 2025
3abdec9
minor fix and code cleanup
hchings Aug 21, 2025
dc9e398
skip build mesh for single rank world
tongyuantongyu Aug 22, 2025
e0fd152
Warm up ray queue actors and optimize pickle out.
joyang-nv Aug 21, 2025
4bb1ef8
Clean up.
joyang-nv Aug 24, 2025
ddde721
nanobind fix and a few cleanup
hchings Aug 24, 2025
d87dce3
revert working extension; some renamings.
hchings Aug 25, 2025
599db47
Remove update_weights and reset_kv_cache and collective_rpc api
shuyixiong Aug 25, 2025
c5b5b33
Cleanup
shuyixiong Aug 26, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -46,6 +46,7 @@ tensorrt_llm/deep_ep_cpp_tllm.pyi
tensorrt_llm/deep_gemm/
tensorrt_llm/deep_gemm_cpp_tllm.*.so
tensorrt_llm/deep_gemm_cpp_tllm.pyi
tensorrt_llm/pg_utils_bindings.*.so
*docs/cpp_docs*
*docs/source/_cpp_gen*
docs/source/**/*.rst
144 changes: 140 additions & 4 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -23,9 +23,17 @@
#include "tensorrt_llm/executor/cacheCommunicator.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/custom_class.h>
#include <torch/python.h>
#include <type_traits>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;

@@ -37,6 +45,131 @@ class BaseCacheTransceiver
class DataResponder;
class DataRequester;

class CacheTransceiverComm
{
public:
    // Construct from a non-owning raw pointer; the aliasing shared_ptr
    // constructor (empty owner) guarantees the pointer is never deleted.
    explicit CacheTransceiverComm(mpi::MpiComm const* mpiComm)
        : mMpiComm(std::shared_ptr<mpi::MpiComm const>(nullptr), mpiComm)
    {
    }

    // Construct from a shared_ptr with shared ownership
    explicit CacheTransceiverComm(std::shared_ptr<mpi::MpiComm const> mpiComm)
        : mMpiComm(std::move(mpiComm))
    {
    }

    // Construct from a ProcessGroup communicator
    explicit CacheTransceiverComm(c10::intrusive_ptr<c10d::ProcessGroup> pgComm)
        : mPgComm(std::move(pgComm))
    {
    }

    ~CacheTransceiverComm() = default;

    bool isMpi() const noexcept
    {
        return mMpiComm != nullptr;
    }

    int getRank() const
    {
        if (isMpi())
        {
            return mMpiComm->getRank();
        }
        return mPgComm->getRank();
    }

    int getSize() const
    {
        if (isMpi())
        {
            return mMpiComm->getSize();
        }
        return mPgComm->getSize();
    }

    void allgather(void const* sendbuf, void* recvbuf, int count, mpi::MpiType dtype) const
    {
        if (isMpi())
        {
            mMpiComm->allgather(sendbuf, recvbuf, count, dtype);
            return;
        }
        TLLM_THROW("Input arguments only supported in mpi");
    }

    template <typename Input, typename Output>
    bool allgather(Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
    {
        if (isMpi())
        {
            TLLM_THROW("Input arguments only supported in pg");
        }
        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};

        PGCHECK_THROW(pgh.allgather(input, output, options));
        return true;
    }

    template <typename Input, typename Output>
    bool allgatherv(Input input, Output output, std::vector<int> const& sizes,
        c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
    {
        if (isMpi())
        {
            TLLM_THROW("Input arguments only supported in pg");
        }
        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
        PGCHECK_THROW(pgh.allgatherv(input, output, sizes, options));
        return true;
    }

    bool allgatherv(void const* sendbuf, int sendcount, mpi::MpiType sendtype, void* recvbuf,
        std::vector<int> const& recvcounts, std::vector<int> const& displs, mpi::MpiType recvtype) const
    {
        if (isMpi())
        {
            mMpiComm->allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype);
            return true;
        }
        TLLM_THROW("Input arguments only supported in mpi");
    }

    // Split into subcommunicators with MPI_Comm_split semantics: ranks passing
    // the same color form one subgroup, ordered by key.
    CacheTransceiverComm split(int color, int key)
    {
        if (isMpi())
        {
            auto subgroup = mMpiComm->split(color, key);
            return CacheTransceiverComm(std::make_shared<mpi::MpiComm const>(std::move(subgroup)));
        }
        bool const initialized = Py_IsInitialized();
        TLLM_CHECK_WITH_INFO(initialized, "Trying to use ProcessGroup communicator but Python is not initialized");
        try
        {
            pybind11::gil_scoped_acquire gil;
            auto const m = pybind11::module::import("tensorrt_llm._torch.distributed.pg_utils");
            // Properly box the existing intrusive_ptr ProcessGroup into an IValue
            // and convert to a Python object without constructing a new instance.
            auto const py_pg = torch::jit::toPyObject(c10::IValue(mPgComm));

            auto const py_sub_pg = m.attr("split")(color, key, py_pg);
            auto pgSub = torch::jit::toCustomClass<c10d::ProcessGroup>(py_sub_pg);
            return CacheTransceiverComm(pgSub);
        }
        catch (...)
        {
            TLLM_THROW("Failed to split process group");
        }
    }

private:
    std::shared_ptr<mpi::MpiComm const> mMpiComm;
    c10::intrusive_ptr<c10d::ProcessGroup> mPgComm;
};

class CacheTransceiverFactory
{
public:
@@ -114,9 +247,12 @@ class CacheTransceiver : public BaseCacheTransceiver
    std::unique_ptr<DataRequester> mDataRequester;
    std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
    std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
    mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
    std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
        mMpiGroupTPInDPComm;
    // only for mpi backend, don't need it for ucx backend
    mpi::MpiComm const* mMpiWorldComm{nullptr};

    std::shared_ptr<CacheTransceiverComm> mGroupComm;
    std::shared_ptr<CacheTransceiverComm> mGroupTensorParaComm, mGroupPipeParaComm, mGroupDataComm, mGroupTPInDPComm;

    executor::kv_cache::CommState const* mCommState;
    std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
    std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
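As a brief aside, here is a minimal, hypothetical usage sketch of the new CacheTransceiverComm wrapper. The enclosing namespace (tensorrt_llm::batch_manager) and the MpiComm::world() accessor are assumptions based on the surrounding code, and the color/key values are arbitrary; only the wrapper's API itself comes from the diff above.

#include "tensorrt_llm/batch_manager/cacheTransceiver.h"

// Sketch only: non-owning wrap of a process-wide MPI communicator,
// a rank query, and an MPI_Comm_split-style subgroup split.
void cacheTransceiverCommSketch()
{
    using tensorrt_llm::batch_manager::CacheTransceiverComm;

    // The raw-pointer constructor does not take ownership (aliasing shared_ptr).
    CacheTransceiverComm comm(&tensorrt_llm::mpi::MpiComm::world());

    int const rank = comm.getRank();

    // Ranks passing the same color form one subgroup, ordered by key.
    CacheTransceiverComm sub = comm.split(/*color=*/rank % 2, /*key=*/rank);
    (void) sub; // silence unused-variable warning in this sketch

    // The ProcessGroup path instead constructs from a
    // c10::intrusive_ptr<c10d::ProcessGroup>; its split() calls back into
    // Python via tensorrt_llm._torch.distributed.pg_utils.split.
}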
19 changes: 19 additions & 0 deletions cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h
@@ -35,6 +35,7 @@
#include <cstdlib>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>

#if ENABLE_MULTI_DEVICE
@@ -425,7 +426,25 @@ class MpiComm
        return !(rhs == *this);
    }

    bool couldUseMPI() const
    {
        if (!mDisableMPI.has_value())
        {
            char* val = std::getenv("TLLM_DISABLE_MPI");
            ;
            bool disable_mpi = false;
            if (val != NULL && std::string(val) == "1")
            {
                throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
            }
            mDisableMPI = disable_mpi;
        }

        return mDisableMPI.value();
    }
Review comment from a contributor on lines +429 to +444:

⚠️ Potential issue

couldUseMPI() returns the wrong value and throws where a pure query is expected

  • The function name implies a pure check, but it throws when disabled; this makes if (couldUseMPI()) branches unusable.
  • It returns mDisableMPI.value() (i.e., “disabled?”) instead of “could use?”, causing it to return false when MPI is enabled.
  • It also relies on std::string without including <string> and uses std::runtime_error without including <stdexcept>.

This will break the intended backend gating and can cause spurious exceptions.

Apply this fix to make the API boolean and non-throwing; let callers decide behavior:

-    bool couldUseMPI() const
+    [[nodiscard]] bool couldUseMPI() const
     {
         if (!mDisableMPI.has_value())
         {
-            char* val = std::getenv("TLLM_DISABLE_MPI");
-            ;
-            bool disable_mpi = false;
-            if (val != NULL && std::string(val) == "1")
-            {
-                throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
-            }
-            mDisableMPI = disable_mpi;
+            char const* val = std::getenv("TLLM_DISABLE_MPI");
+            bool const disableMpi = (val != nullptr) && (std::strcmp(val, "1") == 0);
+            mDisableMPI = disableMpi;
         }
-
-        return mDisableMPI.value();
+        // couldUse == !disabled
+        return !mDisableMPI.value();
     }

And add the missing headers (including <cstring> for the std::strcmp used in the fix):

+#include <cstring>
 #include <memory>
 #include <mutex>
+#include <optional>
+#include <stdexcept>
+#include <string>
 #include <thread>

If you prefer to raise on misuse, add a separate ensureMPIEnabled() that throws, while couldUseMPI() remains a pure query.
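A minimal sketch of that separate throwing variant, assuming couldUseMPI() has been rewritten as the pure query above (the placement inside MpiComm and the message text are illustrative):

    //! \brief Throwing counterpart to the pure couldUseMPI() query; sketch only.
    void ensureMPIEnabled() const
    {
        if (!couldUseMPI())
        {
            throw std::runtime_error("MPI is disabled via TLLM_DISABLE_MPI=1");
        }
    }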



private:
    mutable std::optional<bool> mDisableMPI;
    //! \brief Corresponds to `world()` by default, but can be overridden per process.
    static MpiComm& mutableSession();

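Finally, with the reviewer's fixed query semantics, backend selection could be gated in one place. A hedged sketch follows; the makeGroupComm helper and its construction details are illustrative assumptions, not part of this PR (CacheTransceiverComm is the wrapper from cacheTransceiver.h above):

// Sketch: choose the communicator backend once, driven by TLLM_DISABLE_MPI.
std::shared_ptr<CacheTransceiverComm> makeGroupComm(c10::intrusive_ptr<c10d::ProcessGroup> pg)
{
    auto const& world = tensorrt_llm::mpi::MpiComm::world();
    if (world.couldUseMPI())
    {
        // Non-owning wrap: the world communicator outlives the wrapper.
        return std::make_shared<CacheTransceiverComm>(&world);
    }
    // MPI disabled: fall back to the torch.distributed ProcessGroup path.
    return std::make_shared<CacheTransceiverComm>(std::move(pg));
}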