Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support latest trtllm
  • Loading branch information
SchumiDing committed Feb 2, 2026
commit bf71c9b4c3c19b6496133d93568119aea6d8951d
48 changes: 0 additions & 48 deletions tests/workers/config/test_optim_config_on_cpu.py

This file was deleted.

96 changes: 80 additions & 16 deletions verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
import base64
import inspect
import pickle
from typing import Optional

import torch

from tensorrt_llm import serialization
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
Expand All @@ -42,30 +44,85 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
try:
if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
for module in self.engine.model_engine.model.modules():
if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
if hasattr(module, "pre_reload_weights") and not getattr(
module, "_weights_removed", False
):
module.pre_reload_weights()
self.engine.model_engine.model.first_pre_reload_weights = True
setattr(self.engine.model_engine.model, "first_pre_reload_weights", True)

if ipc_handles is not None:
device_uuid = get_device_uuid()
handles = ipc_handles.get(device_uuid, None)
if handles is not None:
weights = pickle.loads(base64.b64decode(handles))
model = self.engine.model_engine.model
load_weights_args = inspect.getfullargspec(model.load_weights).args
supports_partial_loading = "allow_partial_loading" in load_weights_args

if supports_partial_loading:
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)
else:
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False)
logger.info("Update weights from IPC handles")
device_uuid = get_device_uuid(self.device_id)

if device_uuid not in ipc_handles:
raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

weights = {}

serialized_handles = ipc_handles[device_uuid]
if isinstance(serialized_handles, str):
# Data is base64-encoded pickled bytes - deserialize it
# using restricted unpickler from tensorrt_llm.serialization
logger.info("Deserializing base64-encoded weight handles")
decoded_data = base64.b64decode(serialized_handles)
# Allow basic builtins and all torch modules
approved_imports = {
"builtins": [
"list",
"tuple",
"str",
"int",
"float",
"bool",
"bytes",
"dict",
"NoneType",
"type",
],
}
all_handles = serialization.loads(
decoded_data,
approved_imports=approved_imports,
approved_module_patterns=[r"^torch.*"],
)

# Verify the result is a list as expected
if not isinstance(all_handles, list):
raise ValueError(
f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
)
else:
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles

for param_name, tensor_handle in all_handles:
func, args = tensor_handle
list_args = list(args)
list_args[6] = self.device_id
tensor = func(*list_args)
weights[param_name] = tensor

logger.info(f"weights key size: {len(weights.keys())}")

# Check if model supports partial loading and use appropriate strategy
model = self.engine.model_engine.model
load_weights_args = inspect.getfullargspec(model.load_weights).args
supports_partial_loading = "allow_partial_loading" in load_weights_args

if supports_partial_loading:
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)
else:
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False)
else:
logger.info("Finalize update weights")
for module in self.engine.model_engine.model.modules():
if hasattr(module, "process_weights_after_loading") and not getattr(
module, "_weights_removed", False
):
module.process_weights_after_loading()
if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
if isinstance(moe_load_balancer, MoeLoadBalancer):
Expand All @@ -79,3 +136,10 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
except Exception as e:
logger.error("Encountered an error in update_weights")
raise e

def check_weights_updated(self) -> bool:
    """Return ``True`` iff every parameter of the rollout model is all-zero.

    Presumably a debug/sanity probe for the weight-update path (e.g. the
    trainer zeroes its weights, pushes them, then calls this to confirm the
    rollout engine received the update) — TODO confirm against the caller.

    Returns:
        bool: ``True`` when every tensor from
        ``self.engine.model_engine.model.named_parameters()`` is element-wise
        close to zero (default ``torch.allclose`` tolerances), else ``False``.
    """
    # all() short-circuits on the first non-zero parameter, unlike the
    # accumulate-with-`and` pattern which would keep scanning the full list.
    return all(
        torch.allclose(param, torch.zeros_like(param))
        for _, param in self.engine.model_engine.model.named_parameters()
    )