Precommit check
SchumiDing committed Jan 31, 2026
commit 0394ab512fdefe6be8a6b8fa5e2393dfa5e0777e
5 changes: 3 additions & 2 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -161,6 +161,7 @@ async def launch_server(self):

if self.is_vlm_model:
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig

multimodal_config = MultimodalServerConfig(
Collaborator:
Can you add a unittest for this new feature? There is a test_trtllm_async_server.py.

Contributor Author:
sure, I'm adding one

Contributor Author:
The test script and related test workflow have been added.

Contributor Author:
I didn't find test_trtllm_async_server.py in the verl repo, so I wrote a test script that covers both the LLM and VLM rollout paths of the TensorRT-LLM rollout worker.

Collaborator (@hchings, Feb 3, 2026):
> I didn't find the test_trtllm_async_server.py in verl repo

We have a unit test MR that should be merged shortly; it contains test_trtllm_async_server.py.
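For illustration, a minimal sketch of what such a test could look like. The class name `TRTLLMAsyncServer`, the import path, and the patch targets are assumptions for the sketch, not verl's actual API; the authoritative version is the test_trtllm_async_server.py in the upcoming unit test MR.

```python
# Hypothetical sketch only -- the class name and patch targets below are
# assumptions, not verl's actual API.
from unittest import mock

import pytest


@pytest.mark.asyncio
async def test_launch_server_builds_multimodal_config_for_vlm():
    from verl.workers.rollout.trtllm_rollout import trtllm_async_server as mod

    server = mock.Mock(spec=mod.TRTLLMAsyncServer)  # assumed class name
    server.is_vlm_model = True

    # Patch at the source module, since launch_server imports it lazily.
    with mock.patch("tensorrt_llm.inputs.multimodal.MultimodalServerConfig") as mm_cfg:
        # The remaining collaborators (the trtllm OpenAI server, run_unvicorn)
        # would need stubbing too; elided here for brevity.
        await mod.TRTLLMAsyncServer.launch_server(server)
        mm_cfg.assert_called_once()  # VLM branch constructed the config
```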

media_io_kwargs={
"image": {
@@ -193,11 +194,10 @@ async def launch_server(self):
server_role=None,
metadata_server_cfg=None,
)

app = trtllm_server.app
self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address)

@resume_on_abort
async def generate(
self,
prompt_ids: list[int],
Expand All @@ -207,6 +207,7 @@ async def generate(
video_data: Optional[list[Any]] = None,
) -> TokenOutput:
from tensorrt_llm.llmapi import SamplingParams

max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids))
sampling_params["max_tokens"] = max_tokens
sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None
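To make the clamp above concrete, a quick worked example with illustrative numbers:

```python
# Worked example of the max_tokens clamp (illustrative numbers).
response_length = 1024            # config.response_length
max_model_len = 4096              # config.max_model_len
prompt_ids = list(range(3500))    # a long prompt

max_tokens = min(response_length, max_model_len - len(prompt_ids))
assert max_tokens == 596  # the prompt leaves only 596 tokens of headroom

# With a short prompt, response_length wins the min():
prompt_ids = list(range(100))
assert min(response_length, max_model_len - len(prompt_ids)) == 1024
```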
4 changes: 2 additions & 2 deletions verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
@@ -281,7 +281,7 @@ def __init__(
self.is_leader_rank = None
self.replica_rank = None
self.is_dp_rank = None
self._supports_partial_loading = None
self._supports_partial_loading = None

# hybrid mode
if self.device_mesh is not None:
@@ -421,7 +421,7 @@ async def flush():
await self.update_weights_from_ipc_handles(serialized_device_handles)
cur_available_bytes = total_available_bytes
cur_handles = []

# Query if model supports partial loading
supports_partial_loading = await self.get_supports_partial_loading()
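The surrounding hunk follows an accumulate-and-flush pattern: handles are batched until the byte budget is exhausted, then pushed in one update. A standalone sketch of that pattern, where `size_of` and `serialize` are hypothetical stand-ins for however verl measures and packages the handles:

```python
# Sketch of the accumulate-and-flush batching pattern, assuming hypothetical
# helpers size_of() and serialize(); only the shape of the loop is real.
async def update_in_chunks(self, handles, total_available_bytes):
    cur_available_bytes = total_available_bytes
    cur_handles = []

    async def flush():
        nonlocal cur_available_bytes, cur_handles
        if cur_handles:
            await self.update_weights_from_ipc_handles(serialize(cur_handles))
            cur_available_bytes = total_available_bytes  # reset the budget
            cur_handles = []

    for handle in handles:
        if size_of(handle) > cur_available_bytes:
            await flush()  # free the budget before accepting this handle
        cur_handles.append(handle)
        cur_available_bytes -= size_of(handle)
    await flush()  # push any remainder
```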

32 changes: 18 additions & 14 deletions verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
@@ -1,3 +1,16 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import inspect
import pickle
@@ -10,7 +23,6 @@


class WorkerExtension:
Collaborator (@Superjomn, Feb 2, 2026):

Here is the latest WorkerExtension in the TensorRT-LLM repo. Is there any motivation for implementing a new one in the verl repo? I am thinking about how to unify both. Ideally, we would update the one in the TensorRT-LLM codebase, but if we need a minor change to it before the next trtllm version bump, @hchings, do you have a suggestion?

Contributor Author:
Yeah. Ideally, we should still use the worker extension from the tensorrt-llm repo. But to support models that do not allow partial loading, self.engine.model_engine.model_loader.reload needs to be callable with allow_partial_loading=False.

Contributor Author (@SchumiDing, Feb 2, 2026):
I added this new worker extension to support allow_partial_loading=False, because tensorrt-llm always sets this param to True, but some models do not support partial loading.

Collaborator:
I'd prefer that we keep this in the TensorRT-LLM repo instead and make it generic for other RL frameworks to reuse in the future.
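The compatibility check the extension relies on boils down to signature introspection. A minimal self-contained sketch, with `reload` standing in for `model_loader.reload`:

```python
import inspect


def supports_partial_loading(reload_fn) -> bool:
    # True if reload_fn accepts an allow_partial_loading parameter.
    return "allow_partial_loading" in inspect.signature(reload_fn).parameters


# Stand-in for model_loader.reload with the newer signature:
def reload(model, weights, allow_partial_loading=True):
    ...


assert supports_partial_loading(reload)
```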


def __init__(self):
pass

@@ -30,11 +42,9 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
try:
if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
for module in self.engine.model_engine.model.modules():
if hasattr(module, "pre_reload_weights") and not getattr(
module, "_weights_removed", False
):
if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
module.pre_reload_weights()
setattr(self.engine.model_engine.model, "first_pre_reload_weights", True)
self.engine.model_engine.model.first_pre_reload_weights = True

if ipc_handles is not None:
device_uuid = get_device_uuid()
@@ -46,22 +56,16 @@
supports_partial_loading = "allow_partial_loading" in load_weights_args

if supports_partial_loading:
self.engine.model_engine.model_loader.reload(
model, weights, allow_partial_loading=True
)
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)
else:
self.engine.model_engine.model_loader.reload(
model, weights, allow_partial_loading=False
)
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False)
else:
for module in self.engine.model_engine.model.modules():
if hasattr(module, "process_weights_after_loading") and not getattr(
module, "_weights_removed", False
):
module.process_weights_after_loading()
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
module.post_load_weights()
moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
if isinstance(moe_load_balancer, MoeLoadBalancer):