[single_controller] feat: support dispatch tensordict #4213

vermouth1992 merged 5 commits into verl-project:main
Conversation
Code Review
This pull request adds support for dispatching TensorDict objects, including those with nested tensors, within the single_controller framework. The changes span across the core decorator logic, utility functions for TensorDict, and corresponding test files. My review focuses on improving the robustness and maintainability of the new utility functions and ensuring correctness in the updated dispatch logic. I've identified a critical issue related to input mutation in a new utility function, as well as some high-severity issues concerning error messaging, local imports, and overly restrictive assertions.
```python
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    # pop all the nested tensor if any
    nested_tensors = {}

    # find nested tensor
    for key in data[0].keys():
        tensor = data[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data:
                assert d[key].is_nested

    for key in nested_tensors.keys():
        for d in data:
            nested_tensors[key].append(d.pop(key))

    # concat nested tensor
    for key in nested_tensors.keys():
        nested_tensors[key] = concat_nested_tensors(nested_tensors[key])

    # concat rest
    output = TensorDict.cat(data, dim=0)

    # put together
    for key in nested_tensors.keys():
        output[key] = nested_tensors[key]

    return output
```
This function has two critical issues:

- It will raise an `IndexError` if the input `data` list is empty, as it tries to access `data[0]`.
- It modifies the `TensorDict` objects in the input list `data` in-place by calling `d.pop(key)`. Mutating input arguments in a utility function is a dangerous side effect that can lead to subtle and hard-to-debug issues in the calling code.

The suggested code below fixes both issues by adding a check for an empty list and by operating on shallow copies of the input `TensorDict` objects.
```python
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    if not data:
        return TensorDict({}, batch_size=[0])
    # Create shallow copies to avoid modifying original tensordicts
    data_copies = [d.clone(recurse=False) for d in data]
    # pop all the nested tensor if any
    nested_tensors = {}
    # find nested tensor
    for key in data_copies[0].keys():
        tensor = data_copies[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data_copies:
                assert d[key].is_nested
    for key in nested_tensors.keys():
        for d in data_copies:
            nested_tensors[key].append(d.pop(key))
    # concat nested tensor
    for key in nested_tensors.keys():
        nested_tensors[key] = concat_nested_tensors(nested_tensors[key])
    # concat rest
    output = TensorDict.cat(data_copies, dim=0)
    # put together
    output.update(nested_tensors)
    return output
```

```python
from verl.utils.tensordict_utils import concat_tensordict

return concat_tensordict(output)
```
The import of `concat_tensordict` is performed locally. According to PEP 8, imports should be at the top of the file to make dependencies clear. Since there appears to be no circular dependency here, please move this import to the top level of the module. This will improve readability and maintainability.

Note: You will need to add `from verl.utils.tensordict_utils import concat_tensordict` to the top-level imports of this file.
```diff
-from verl.utils.tensordict_utils import concat_tensordict
-return concat_tensordict(output)
+return concat_tensordict(output)
```
```python
assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
    f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}"
)
```
The error message in this assertion is outdated and incomplete. It should include `TensorDict` and `ray.ObjectRef` as possible types, consistent with the `isinstance` check. Keeping error messages accurate is important for debugging.
```diff
 assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
-    f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}"
+    f"expecting {o} to be DataProto, ray.ObjectRef, BatchMeta, or TensorDict, but got {type(o)}"
 )
```
```python
assert tensor.is_nested and tensor.is_contiguous()
unbind_tensors = []
for tensor in tensors:
    assert len(tensor.shape) == 2
```
The assertion `assert len(tensor.shape) == 2` is overly restrictive, limiting this function to nested tensors created from 1D tensors (like token sequences). The `tensor.unbind(0)` logic should work correctly for higher-dimensional nested tensors as well (e.g., a batch of images with varying sizes). Removing this assertion would make this utility function more general and reusable without compromising correctness for the current use cases.
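A quick illustration of why the 2D restriction is not forced by the unbind-based logic: `torch.Tensor.unbind(0)` splits along the first dimension for tensors of any rank (the shapes below are made up for the sketch):

```python
import torch

x2d = torch.arange(6).reshape(3, 2)
x4d = torch.zeros(3, 2, 4, 4)  # e.g. a batch of small feature maps

# unbind(0) yields one tensor per element of dim 0, regardless of rank
assert len(x2d.unbind(0)) == 3
assert x4d.unbind(0)[0].shape == (2, 4, 4)
```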
/gemini review
Code Review
This pull request adds support for dispatching TensorDict objects, including those with nested tensors. The changes span across the single controller's decorator logic and utility functions for TensorDict. New tests have been added to validate the functionality. My main feedback is regarding the implementation of concat_tensordict, which currently modifies its input arguments. This can lead to unexpected side effects and makes the code harder to reason about. I've suggested a refactoring to make the function pure, improving its robustness and predictability. Other changes look good and are well-tested.
```python
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    assert len(data) > 0, "Must have at least one tensordict"

    # pop all the nested tensor if any
    nested_tensors = {}

    # find nested tensor
    for key in data[0].keys():
        tensor = data[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data:
                assert d[key].is_nested

    for key in nested_tensors.keys():
        for d in data:
            nested_tensors[key].append(d.pop(key))

    # concat rest
    output = TensorDict.cat(data, dim=0)

    # concat nested tensor
    for key in nested_tensors.keys():
        output[key] = concat_nested_tensors(nested_tensors[key])
        # add nested tensor back
        for i, d in enumerate(data):
            d[key] = nested_tensors[key][i]

    return output
```
The `concat_tensordict` function modifies its input `data` by popping keys. This is a side effect that can lead to unexpected behavior, especially if an error occurs during concatenation, leaving the input `TensorDict` objects in a corrupted state. The associated test `test_concat_tensordict` verifies that the inputs are not modified, which indicates that the intended behavior is to be side-effect-free.
To make the function more robust and predictable, it should be refactored to avoid modifying its inputs. A pure function without side effects is easier to reason about and less prone to bugs. I've suggested an alternative implementation that achieves the same result without mutating the input data.
```python
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    assert len(data) > 0, "Must have at least one tensordict"
    # Find nested tensor keys from the first tensordict
    nested_tensor_keys = {
        key for key, value in data[0].items() if isinstance(value, torch.Tensor) and value.is_nested
    }
    if not nested_tensor_keys:
        return TensorDict.cat(data, dim=0)
    # Create a list of tensordicts containing only non-nested tensors for concatenation
    regular_tds = []
    for td in data:
        current_nested_keys = {k for k, v in td.items() if isinstance(v, torch.Tensor) and v.is_nested}
        assert current_nested_keys == nested_tensor_keys, "All tensordicts must have the same set of nested tensors."
        # Create a new TensorDict with non-nested items without modifying the original
        regular_items = {k: v for k, v in td.items() if k not in nested_tensor_keys}
        regular_tds.append(TensorDict(regular_items, batch_size=td.batch_size, device=td.device))
    # Concatenate the regular tensordicts
    output = TensorDict.cat(regular_tds, dim=0)
    # Concatenate and add nested tensors to the output
    for key in nested_tensor_keys:
        nested_tensors_to_concat = [td[key] for td in data]
        output[key] = concat_nested_tensors(nested_tensors_to_concat)
    return output
```
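The mutation hazard the review describes is independent of TensorDict itself; a dependency-free sketch of the same contract (inputs must come back unchanged), using plain dicts of lists as hypothetical stand-ins for tensordicts:

```python
def concat_pure(batches: list[dict]) -> dict:
    """Concatenate per-key lists without mutating any input batch."""
    assert batches, "must have at least one batch"
    merged = {}
    for key in batches[0]:
        # read-only access: no pop(), so callers' batches stay intact
        merged[key] = [item for batch in batches for item in batch[key]]
    return merged

a = {"x": [1, 2], "y": [10, 20]}
b = {"x": [3], "y": [30]}
out = concat_pure([a, b])
assert out == {"x": [1, 2, 3], "y": [10, 20, 30]}
assert a == {"x": [1, 2], "y": [10, 20]}  # inputs unmodified
```

A pop-based version would leave `a` and `b` emptied after the call, which is exactly the failure mode the suggested refactoring avoids.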
/gemini review
Code Review
This pull request introduces support for dispatching TensorDict objects, including those with nested tensors, within the single controller framework. The changes are well-structured, touching utility functions, decorators, and adding comprehensive tests. The core logic for handling TensorDict and nested tensor concatenation is implemented in verl/utils/tensordict_utils.py. The modifications to verl/single_controller/base/decorator.py correctly integrate TensorDict into the dispatch mechanism. The new tests in tests/single_controller/test_device_mesh_register.py and tests/test_protocol_v2_on_cpu.py effectively validate the new functionality. I have found one critical issue in the implementation of concat_nested_tensors that will cause it to fail for valid inputs. Please see the detailed comment.
```python
assert tensor.is_nested and tensor.is_contiguous()
unbind_tensors = []
for tensor in tensors:
    assert len(tensor.shape) == 2
```
The assertion `len(tensor.shape) == 2` is incorrect for a nested tensor created from a list of 1D tensors. For such a nested tensor, `len(tensor.shape)` is 1 (representing the batch dimension), while `tensor.dim()` is 2 (batch dimension + ragged dimension). This assertion will fail for valid 2D nested tensors. It should be changed to `tensor.dim() == 2` to correctly check for a 2D nested tensor structure.
```diff
-    assert len(tensor.shape) == 2
+    assert tensor.dim() == 2
```
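A small check of the distinction, assuming a recent PyTorch with `torch.nested` support (the example tensors are made up):

```python
import torch

# Nested tensor built from 1D sequences of unequal length (token-like data)
nt = torch.nested.nested_tensor([torch.arange(3.0), torch.arange(5.0)])

# dim() counts the batch dimension plus the ragged dimension
assert nt.dim() == 2
# unbind(0) recovers the original variable-length components
assert [t.shape[0] for t in nt.unbind(0)] == [3, 5]
```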
### What does this PR do?

- Support dispatch tensordict including nested tensor

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)