[rollout] feat: Fix partial load problem, Add vlm support for trtllm rollout #5149
base: main
Changes from 1 commit
```diff
@@ -1,3 +1,16 @@
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import base64
 import inspect
 import pickle
```
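The pre-existing imports hint at how the weight payload crosses process boundaries. A minimal sketch, assuming the IPC handles travel as a base64-encoded pickled dict (the encoding convention is an assumption; this hunk only shows the imports, not the producer side):

```python
import base64
import pickle


def decode_ipc_handles(encoded: str) -> dict:
    """Decode a serialized IPC-handle payload (hypothetical helper).

    Assumes the sender pickled a dict of per-device handles and
    base64-encoded it for transport; adjust to the actual wire format.
    """
    return pickle.loads(base64.b64decode(encoded))
```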
@@ -10,7 +23,6 @@ | |
|
|
||
|
|
||
| class WorkerExtension: | ||
|
Collaborator
Here is the latest WorkerExtension in the TensorRT-LLM repo. Are there any motivations for implementing a new one in the verl repo? I am thinking about how to unify both. Ideally, we may update the one in the TensorRT-LLM codebase, but if we need a minor change on it before the next trtllm version bump, @hchings do you have a suggestion?

Contributor (Author)
Yeah, ideally we should still use the worker extension from the tensorrt-llm repo. But to support models that do not allow partial loading, the call to self.engine.model_engine.model_loader.reload needs to be usable with allow_partial_loading=False.

Contributor (Author)
I added this new worker extension to support allow_partial_loading=False, because tensorrt-llm always sets this param to True, but some models do not support partial loading.

Collaborator
I'd prefer that we keep this in the TensorRT-LLM repo instead and make it generic for other RL frameworks to reuse in the future.
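The capability check the author describes comes down to a signature probe. A minimal sketch of that idea, assuming the model exposes a load_weights method as TensorRT-LLM PyTorch-backend models do (names are illustrative, not the exact verl code):

```python
import inspect


def model_allows_partial_loading(model) -> bool:
    """Return True if the model's load_weights accepts allow_partial_loading."""
    load_weights_args = inspect.signature(model.load_weights).parameters
    return "allow_partial_loading" in load_weights_args
```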
```diff

     def __init__(self):
         pass
```
@@ -30,11 +42,9 @@ def update_weights(self, ipc_handles: Optional[dict] = None): | |
| try: | ||
| if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): | ||
| for module in self.engine.model_engine.model.modules(): | ||
| if hasattr(module, "pre_reload_weights") and not getattr( | ||
| module, "_weights_removed", False | ||
| ): | ||
| if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): | ||
| module.pre_reload_weights() | ||
| setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) | ||
| self.engine.model_engine.model.first_pre_reload_weights = True | ||
|
|
||
| if ipc_handles is not None: | ||
| device_uuid = get_device_uuid() | ||
|
|
@@ -46,22 +56,16 @@ def update_weights(self, ipc_handles: Optional[dict] = None): | |
| supports_partial_loading = "allow_partial_loading" in load_weights_args | ||
|
|
||
| if supports_partial_loading: | ||
| self.engine.model_engine.model_loader.reload( | ||
| model, weights, allow_partial_loading=True | ||
| ) | ||
| self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) | ||
| else: | ||
| self.engine.model_engine.model_loader.reload( | ||
| model, weights, allow_partial_loading=False | ||
| ) | ||
| self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) | ||
| else: | ||
| for module in self.engine.model_engine.model.modules(): | ||
| if hasattr(module, "process_weights_after_loading") and not getattr( | ||
| module, "_weights_removed", False | ||
| ): | ||
| module.process_weights_after_loading() | ||
| if hasattr(module, "post_load_weights") and not getattr( | ||
| module, "_weights_removed", False | ||
| ): | ||
| if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): | ||
| module.post_load_weights() | ||
| moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) | ||
| if isinstance(moe_load_balancer, MoeLoadBalancer): | ||
|
|
||
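Putting the pieces together, the reload path reduces to one decision: probe the model's load_weights signature, then forward the matching flag to the loader. A condensed sketch of that flow under the same naming assumptions (not the exact diff, which also handles the pre/post weight hooks and MoE load balancing):

```python
import inspect


def reload_with_partial_loading_guard(loader, model, weights) -> None:
    """Reload weights, passing allow_partial_loading only as the model tolerates.

    `loader` stands in for engine.model_engine.model_loader; the probe mirrors
    the supports_partial_loading check in the diff above. All names assumed.
    """
    load_weights_args = inspect.signature(model.load_weights).parameters
    if "allow_partial_loading" in load_weights_args:
        loader.reload(model, weights, allow_partial_loading=True)
    else:
        # Models whose load_weights cannot handle partial loads get an explicit
        # False instead of TensorRT-LLM's default of True.
        loader.reload(model, weights, allow_partial_loading=False)
```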
Collaborator
Can you add a unittest for this new feature? There is a test_trtllm_async_server.py.

Contributor (Author)
Sure, I'm adding one.

Contributor (Author)
The test script and the related test workflow have been added.

Contributor (Author)
I didn't find test_trtllm_async_server.py in the verl repo, so I wrote a test script that covers both the llm rollout and the vlm rollout of the tensorrt-llm rollout worker.

Collaborator
We have a unittest MR that should be merged shortly; it contains test_trtllm_async_server.py.
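Until the shared test_trtllm_async_server.py lands, the signature probe itself is easy to pin down in isolation. A minimal pytest sketch with stand-in models (all names here are hypothetical, not the PR's actual test script):

```python
import inspect


class _PartialAwareModel:
    def load_weights(self, weights, allow_partial_loading=True):
        pass


class _StrictModel:
    def load_weights(self, weights):
        pass


def _allows_partial_loading(model) -> bool:
    return "allow_partial_loading" in inspect.signature(model.load_weights).parameters


def test_allow_partial_loading_detection():
    # The flag should be forwarded only for models that declare it.
    assert _allows_partial_loading(_PartialAwareModel())
    assert not _allows_partial_loading(_StrictModel())
```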