Changes from 1 commit
Commits (30 commits)
dcaacfe
Fix partial load problem, Add vlm support for trtllm rollout
SchumiDing Jan 31, 2026
0394ab5
Precommit check
SchumiDing Jan 31, 2026
0664ab1
Add check for if the model is vlm in trtllmhttpserver
SchumiDing Jan 31, 2026
bf71c9b
Support latest trtllm
SchumiDing Feb 2, 2026
f6e58b8
Support for qwen2.5 vl
SchumiDing Feb 2, 2026
7af6917
Add trtllm rollout test script
SchumiDing Feb 2, 2026
94c4eb0
Add test_trtllm_rollout workflow to test trtllm_rollout
SchumiDing Feb 2, 2026
25518fe
Add back mistakenly deleted file
SchumiDing Feb 2, 2026
fd007fb
Precommit check
SchumiDing Feb 2, 2026
659ec01
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 4, 2026
55b55dc
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
e2cc50b
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 5, 2026
ca17f8a
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 6, 2026
62af0f2
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 11, 2026
24a6620
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
6f055a2
Modified to inherit the worker extension class of tensorrt llm
SchumiDing Feb 11, 2026
d0b1d1d
fix readability problem of multimodal config
SchumiDing Feb 11, 2026
6b021f4
Remove need for multimodal server config
SchumiDing Feb 11, 2026
a7faa7b
Add vlm unit test into existing trtllm unit test
SchumiDing Feb 11, 2026
8519d36
add e2e script to train qwen2.5-vl with trtllm rollout
SchumiDing Feb 11, 2026
9acdcd6
Merge branch 'verl-project:main' into vlm_trtllm_support
SchumiDing Feb 12, 2026
5a145a5
Change import statement
SchumiDing Feb 12, 2026
3776338
remove reward config in e2e script
SchumiDing Feb 12, 2026
1706e71
When multi modal input for trtllm, decode with special token first
SchumiDing Feb 12, 2026
90837f3
rever typo
SchumiDing Feb 12, 2026
57506e2
revert typo
SchumiDing Feb 12, 2026
e193d0d
pre commit check
SchumiDing Feb 12, 2026
81050ce
Fix bugs
SchumiDing Feb 27, 2026
91d8c59
Update
SchumiDing Feb 27, 2026
60dd50b
Update
SchumiDing Feb 27, 2026
When multi modal input for trtllm, decode with special token first
SchumiDing committed Feb 12, 2026
commit 1706e71d33c8f8ac643c89da499c49fb87958d56
2 changes: 1 addition & 1 deletion verl/experimental/agent_loop/agent_loop.py
@@ -113,7 +113,7 @@ async def generate(
server = self._choose_server(request_id)
output = await server.generate.remote(
request_id=uuid4().hex, # use new request_id for each turn
prompt_ids=prompt_ids,
prompt_ids=prompt_ids, # for trtllm, this is the raw prompt
sampling_params=sampling_params,
image_data=image_data,
video_data=video_data,
9 changes: 6 additions & 3 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -184,7 +184,7 @@ async def launch_server(self):

async def generate(
self,
prompt_ids: list[int],
prompt_ids: str,
sampling_params: dict[str, Any],
request_id: str,
image_data: Optional[list[Any]] = None,
@@ -201,16 +201,19 @@ async def generate(

trt_llm_sampling_params = SamplingParams(**sampling_params)
if self.is_vlm_model:
org_prompt = self.llm.tokenizer.decode(prompt_ids)
if image_data or video_data:

input_dict = {
"prompt_token_ids": prompt_ids,
"prompt": org_prompt,
"multi_modal_data": {},
"mm_processor_kwargs": {},
}
if image_data:
input_dict["multi_modal_data"]["image"] = image_data
if video_data:
input_dict["multi_modal_data"]["video"] = video_data

outputs = await self.llm.generate_async(
inputs=input_dict,
sampling_params=trt_llm_sampling_params,
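[Editor's note] For context, a minimal standalone sketch of the multimodal input assembly done in the generate() change above, assuming a HuggingFace-style tokenizer. The helper name build_vlm_inputs is illustrative only; the point of this commit is decoding prompt_ids back to a prompt string while keeping special tokens, so image/video placeholder tokens survive for the multimodal processor.

# Sketch only: assumes a HuggingFace-style tokenizer; build_vlm_inputs is a
# hypothetical helper, not part of the PR.
def build_vlm_inputs(tokenizer, prompt_ids, image_data=None, video_data=None):
    # Decode with special tokens kept so multimodal placeholder tokens
    # (e.g. image pad tokens) remain in the prompt string.
    org_prompt = tokenizer.decode(prompt_ids, skip_special_tokens=False)
    input_dict = {
        "prompt": org_prompt,
        "multi_modal_data": {},
        "mm_processor_kwargs": {},
    }
    if image_data:
        input_dict["multi_modal_data"]["image"] = image_data
    if video_data:
        input_dict["multi_modal_data"]["video"] = video_data
    return input_dict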
@@ -369,7 +372,7 @@ async def launch_servers(self):
node_id=node_id,
soft=False,
),
runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", "TLLM_NUMA_AWARE_WORKER_AFFINITY":"0"}},
name=name,
).remote(
config=self.config,
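[Editor's note] As a side note on the runtime_env change above, a minimal standalone sketch of how Ray propagates per-actor environment variables; the DummyServer actor is hypothetical, only the variable names come from the diff.

import os
import ray

@ray.remote
class DummyServer:  # hypothetical stand-in for the TRT-LLM server actor
    def get_env(self, key: str):
        return os.environ.get(key)

ray.init(ignore_reinit_error=True)
server = DummyServer.options(
    runtime_env={
        "env_vars": {
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
            "TLLM_NUMA_AWARE_WORKER_AFFINITY": "0",
        }
    },
).remote()
# The variables are visible inside the actor process:
print(ray.get(server.get_env.remote("TLLM_NUMA_AWARE_WORKER_AFFINITY")))  # prints "0"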
2 changes: 1 addition & 1 deletion verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
@@ -414,7 +414,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
total_available_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) * 1024 * 1024

try:
device_uuid = get_device_uuid(self.gpu_id)
device_uuid = get_device_uuid(int(self.gpu_id))
except Exception as e:
logger.error(f"Failed to get device UUID in update_weights(): {e}")
device_uuid = None
40 changes: 38 additions & 2 deletions verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
@@ -62,7 +62,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
# using restricted unpickler from tensorrt_llm.serialization
logger.info("Deserializing base64-encoded weight handles")
decoded_data = base64.b64decode(serialized_handles)
# Allow basic builtins and all torch modules
# Allow basic builtins and torch tensor reconstruction classes
approved_imports = {
"builtins": [
"list",
@@ -76,11 +76,47 @@
"NoneType",
"type",
],
"torch": [
"Tensor",
"FloatTensor",
"DoubleTensor",
"HalfTensor",
"BFloat16Tensor",
"IntTensor",
"LongTensor",
"ShortTensor",
"CharTensor",
"ByteTensor",
"BoolTensor",
"Size",
"dtype",
"device",
"float32",
"float16",
"int32",
"int64",
"int16",
"int8",
"uint8",
"bool",
],
"torch.multiprocessing.reductions": [
"rebuild_cuda_tensor",
"rebuild_tensor",
],
"torch._utils": [
"_rebuild_tensor_v2",
],
"torch.storage": [
"_load_from_bytes",
"_TypedStorage",
"UntypedStorage",
"TypedStorage",
],
}
all_handles = serialization.loads(
decoded_data,
approved_imports=approved_imports,
approved_module_patterns=[r"^torch.*"],
)

# Verify the result is a list as expected
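[Editor's note] For readers unfamiliar with allowlist-based deserialization, the pattern behind serialization.loads(decoded_data, approved_imports=...) above can be sketched with the standard library alone; this is a generic illustration, not the tensorrt_llm.serialization API.

import io
import pickle

# Illustrative allowlist in the same spirit as approved_imports above.
APPROVED_IMPORTS = {
    "builtins": ["list", "dict", "str", "int", "NoneType"],
    "torch._utils": ["_rebuild_tensor_v2"],
    "torch.multiprocessing.reductions": ["rebuild_cuda_tensor"],
}

class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Only resolve globals that are explicitly approved; anything else
        # (arbitrary classes or functions) is rejected before it can run.
        if name in APPROVED_IMPORTS.get(module, []):
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"{module}.{name} is not in the allowlist")

def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()

payload = pickle.dumps({"handles": [0, 1, 2]})
print(restricted_loads(payload))  # plain containers pass; unapproved globals raise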