Commits (24)
2003ccb feat: Ray integration (tongyuantongyu, Aug 18, 2025)
fd171ef Gracefully exit. (joyang-nv, Aug 18, 2025)
f560d01 fix single node disagg random failure (hchings, Aug 19, 2025)
b7d77ac revert cpu req (hchings, Aug 19, 2025)
c05b72c cleanup ray process rendering in process tree report (tongyuantongyu, Aug 18, 2025)
6c689d5 Use weakref for results and make ray shared queue exit gracefully. (joyang-nv, Aug 19, 2025)
f4c4113 Use different way to get weak ref of actor handle. (joyang-nv, Aug 19, 2025)
57892a1 Cache transceiver refactor (shuyixiong, Aug 19, 2025)
5fd3097 Remove unused code for ray. (joyang-nv, Aug 19, 2025)
955436f Refine code (shuyixiong, Aug 19, 2025)
c6310d0 update run_cluster.sh for multinode disagg (hchings, Aug 19, 2025)
f2ec920 WAR ProcessGroup not pickleable. (tongyuantongyu, Aug 20, 2025)
5004104 fix ci failure, fix abort_request() (hchings, Aug 20, 2025)
9f2e709 Add heartbeat to prevent trtllm-serve timeout due to gloo. Fix pre-co… (hchings, Aug 20, 2025)
115b971 ccacheTransceiver nits (tongyuantongyu, Aug 21, 2025)
9b76654 fix build (tongyuantongyu, Aug 21, 2025)
3abdec9 minor fix and code cleanup (hchings, Aug 21, 2025)
dc9e398 skip build mesh for single rank world (tongyuantongyu, Aug 22, 2025)
e0fd152 Warm up ray queue actors and optimize pickle out. (joyang-nv, Aug 21, 2025)
4bb1ef8 Clean up. (joyang-nv, Aug 24, 2025)
ddde721 nanobind fix and a few cleanup (hchings, Aug 24, 2025)
d87dce3 revert working extension; some renamings. (hchings, Aug 25, 2025)
599db47 Remove update_weights and reset_kv_cache and collective_rpc api (shuyixiong, Aug 25, 2025)
c5b5b33 Cleanup (shuyixiong, Aug 26, 2025)
Remove update_weights and reset_kv_cache and collective_rpc api
shuyixiong committed Aug 26, 2025
commit 599db4708e0c4e8bf47da5fe4117d2c73190a3da
20 changes: 0 additions & 20 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -737,12 +737,6 @@ class WindowBlockManager
return 0;
}

void resetReuseState()
{
mCachedBlocksRoot
= std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
}

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -1126,14 +1120,6 @@ class BlockManager
//! \brief Update cache offsets for block at index
void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);

void resetReuseState()
{
for (auto& [windowSize, manager] : mWindowBlockManagers)
{
manager.resetReuseState();
}
}

private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
{
@@ -1304,7 +1290,6 @@ class BaseKVCacheManager

virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0;

[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

@@ -1654,11 +1639,6 @@ class KVCacheManager : public BaseKVCacheManager
mBlockManager.flushIterationEvents();
}

void resetReuseState() override
{
mBlockManager.resetReuseState();
}

/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
///
/// @param inputLength The number of input tokens in the sequence.
3 changes: 1 addition & 2 deletions cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -423,8 +423,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds)
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds)
.def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds)
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents)
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState);
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents);

py::enum_<tbk::CacheType>(m, "CacheType")
.value("SELF", tbk::CacheType::kSELF)
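For reference, a short Python sketch of what this binding change means for callers. Here `kv_cache_manager` stands in for an instance of the bound BaseKVCacheManager; how it is constructed is outside this diff, so treat the snippet as an illustration rather than the library's documented API.

def flush_kv_cache_events(kv_cache_manager):
    # Assumes `kv_cache_manager` is a Python binding of BaseKVCacheManager as
    # bound in the file above. After this commit, flush_iteration_events() is
    # still exposed, while reset_reuse_state() is no longer bound.
    kv_cache_manager.flush_iteration_events()
    # Any remaining caller of the removed binding would now raise AttributeError:
    # kv_cache_manager.reset_reuse_state()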
5 changes: 0 additions & 5 deletions examples/llm-api/ray/run_ray_tests.sh
@@ -233,11 +233,6 @@ if [ -f "llm_inference_async_ray.py" ]; then
run_python_file "llm_inference_async_ray.py"
fi

# 4. test_update_weight_from_ipc.py
if [ -f "test_update_weight_from_ipc.py" ]; then
run_python_file "test_update_weight_from_ipc.py"
fi

# Run MPI guarding tests
run_python_file "../llm_inference.py"

231 changes: 0 additions & 231 deletions examples/llm-api/ray/test_update_weight_from_ipc.py

This file was deleted.

33 changes: 12 additions & 21 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -849,9 +849,6 @@ def load_single_module(name, module):
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
# tmp fixes to enable partial updates in old path
if not fw:
continue
if new_name in ['k_proj', 'v_proj']:
num_kv_heads_list = [num_kv_heads
] * len(fw) if isinstance(
@@ -868,27 +865,24 @@ def load_single_module(name, module):
}

module_weights.append(fw)
if module_weights:
module.load_weights(weights=module_weights)
module.load_weights(weights=module_weights)

else:
module_weights = filter_weights(name, weights)
if module_weights:
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])

if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
False) in ["True", "true", "1", "yes", "y"]:
for name, module in tqdm(list(
model.named_modules(remove_duplicate=False)),
for name, module in tqdm(list(model.named_modules()),
desc="Loading weights"):
load_single_module(name, module)
else:
all_modules = dict(model.named_modules(remove_duplicate=False))
all_modules = dict(model.named_modules())
serial_load_modules = []
if preload_weight_modules is not None:
for module in preload_weight_modules:
@@ -904,13 +898,10 @@ def load_single_module(name, module):
del all_modules[module]
pbar.close()

pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
pbar = tqdm(list(model.named_modules()),
desc="Loading weights concurrently")
args_list = [
(name, module)
for name, module in model.named_modules(remove_duplicate=False)
if name not in serial_load_modules
]
args_list = [(name, module) for name, module in model.named_modules()
if name not in serial_load_modules]
run_concurrently(load_single_module, args_list, pbar=pbar)


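For context on the concurrent path retained above: `run_concurrently` is invoked with `(name, module)` pairs and a tqdm progress bar. The sketch below shows one way such a helper could be implemented with a thread pool; it is an illustration under stated assumptions, not the actual TensorRT-LLM helper, and its signature is inferred only from the call site in this diff.

from concurrent.futures import ThreadPoolExecutor

def run_concurrently(func, args_list, pbar=None, max_workers=8):
    # Illustrative sketch: run func(*args) for every tuple in args_list on a
    # thread pool, updating the optional tqdm progress bar as each call finishes.
    # The real helper may differ in signature, worker count, and error handling.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(func, *args) for args in args_list]
        for future in futures:
            future.result()  # surface any exception raised inside a worker
            if pbar is not None:
                pbar.update(1)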