
[single_controller] feat: support dispatch tensordict#4213

Merged
vermouth1992 merged 5 commits into verl-project:main from vermouth1992:chi/dev/dispatch_td on Nov 21, 2025

Conversation

@vermouth1992
Collaborator

What does this PR do?

  • Support dispatch tensordict including nested tensor

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@vermouth1992 vermouth1992 requested review from PeterSH6 and wuxibin89 and removed request for wuxibin89 November 20, 2025 15:35
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for dispatching TensorDict objects, including those with nested tensors, within the single_controller framework. The changes span across the core decorator logic, utility functions for TensorDict, and corresponding test files. My review focuses on improving the robustness and maintainability of the new utility functions and ensuring correctness in the updated dispatch logic. I've identified a critical issue related to input mutation in a new utility function, as well as some high-severity issues concerning error messaging, local imports, and overly restrictive assertions.

Comment on lines 63 to 92
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Supports nested tensors."""

    # pop all the nested tensors, if any
    nested_tensors = {}

    # find nested tensors
    for key in data[0].keys():
        tensor = data[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data:
                assert d[key].is_nested

    for key in nested_tensors.keys():
        for d in data:
            nested_tensors[key].append(d.pop(key))

    # concat nested tensors
    for key in nested_tensors.keys():
        nested_tensors[key] = concat_nested_tensors(nested_tensors[key])

    # concat the rest
    output = TensorDict.cat(data, dim=0)

    # put together
    for key in nested_tensors.keys():
        output[key] = nested_tensors[key]

    return output
Contributor

critical

This function has two critical issues:

  1. It will raise an IndexError if the input data list is empty, as it tries to access data[0].
  2. It modifies the TensorDict objects in the input list data in-place by calling d.pop(key). Mutating input arguments in a utility function is a dangerous side effect that can lead to subtle and hard-to-debug issues in the calling code.

The suggested code below fixes both issues by adding a check for an empty list and by operating on shallow copies of the input TensorDict objects.

def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    if not data:
        return TensorDict({}, batch_size=[0])

    # Create shallow copies to avoid modifying original tensordicts
    data_copies = [d.clone(recurse=False) for d in data]

    # pop all the nested tensor if any
    nested_tensors = {}

    # find nested tensor
    for key in data_copies[0].keys():
        tensor = data_copies[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data_copies:
                assert d[key].is_nested

    for key in nested_tensors.keys():
        for d in data_copies:
            nested_tensors[key].append(d.pop(key))

    # concat nested tensor
    for key in nested_tensors.keys():
        nested_tensors[key] = concat_nested_tensors(nested_tensors[key])

    # concat reset
    output = TensorDict.cat(data_copies, dim=0)

    # put together
    output.update(nested_tensors)

    return output

Comment on lines 159 to 161
from verl.utils.tensordict_utils import concat_tensordict

return concat_tensordict(output)
Contributor

high

The import of concat_tensordict is performed locally. According to PEP 8, imports should be at the top of the file to make dependencies clear. Since there appears to be no circular dependency here, please move this import to the top level of the module. This will improve readability and maintainability.

Note: You will need to add from verl.utils.tensordict_utils import concat_tensordict to the top-level imports of this file.

Suggested change
- from verl.utils.tensordict_utils import concat_tensordict
-
- return concat_tensordict(output)
+ return concat_tensordict(output)

Comment on lines 281 to 283
assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
    f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}"
)
Contributor

high

The error message in this assertion is outdated and incomplete. It should include TensorDict and ray.ObjectRef as possible types, consistent with the isinstance check. Keeping error messages accurate is important for debugging.

Suggested change
- assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
-     f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}"
- )
+ assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta | TensorDict), (
+     f"expecting {o} to be DataProto, ray.ObjectRef, BatchMeta, or TensorDict, but got {type(o)}"
+ )

assert tensor.is_nested and tensor.is_contiguous()
unbind_tensors = []
for tensor in tensors:
    assert len(tensor.shape) == 2
Contributor

high

The assertion assert len(tensor.shape) == 2 is overly restrictive, limiting this function to nested tensors created from 1D tensors (like token sequences). The tensor.unbind(0) logic should work correctly for higher-dimensional nested tensors as well (e.g., a batch of images with varying sizes). Removing this assertion would make this utility function more general and reusable without compromising correctness for the current use cases.
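The point generalizes because `unbind(0)` splits a nested tensor into its constituent rows regardless of each row's rank. A minimal pure-Python sketch of this unbind-and-renest strategy (plain lists stand in for nested tensors, and `concat_nested` is a hypothetical name for this sketch, not the PR's `concat_nested_tensors`):

```python
def concat_nested(nested_batches):
    """Concatenate ragged batches along dim 0: unbind each batch into its
    rows, then re-nest the combined rows into one batch.

    Plain Python lists stand in for nested tensors; this is a sketch of
    the strategy, not the PR's implementation.
    """
    rows = []
    for batch in nested_batches:
        rows.extend(batch)  # analogue of tensor.unbind(0)
    return rows  # analogue of re-nesting the rows into one nested tensor


a = [[1, 2], [3, 4, 5]]             # batch of two ragged "1-D rows"
b = [[[6, 7]], [[8, 9], [10, 11]]]  # batch of two ragged "2-D rows"
merged = concat_nested([a, b])
print(len(merged))  # 4 -- the per-row rank never mattered
```

Nothing in the loop inspects a row's shape, which is why dropping the rank-2 assertion would not compromise correctness.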

@vermouth1992
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for dispatching TensorDict objects, including those with nested tensors. The changes span across the single controller's decorator logic and utility functions for TensorDict. New tests have been added to validate the functionality. My main feedback is regarding the implementation of concat_tensordict, which currently modifies its input arguments. This can lead to unexpected side effects and makes the code harder to reason about. I've suggested a refactoring to make the function pure, improving its robustness and predictability. Other changes look good and are well-tested.

Comment on lines 63 to 92
def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Supports nested tensors."""
    assert len(data) > 0, "Must have at least one tensordict"

    # pop all the nested tensors, if any
    nested_tensors = {}

    # find nested tensors
    for key in data[0].keys():
        tensor = data[0][key]
        if isinstance(tensor, torch.Tensor) and tensor.is_nested:
            nested_tensors[key] = []
            for d in data:
                assert d[key].is_nested

    for key in nested_tensors.keys():
        for d in data:
            nested_tensors[key].append(d.pop(key))

    # concat the rest
    output = TensorDict.cat(data, dim=0)

    # concat nested tensors
    for key in nested_tensors.keys():
        output[key] = concat_nested_tensors(nested_tensors[key])
        # add nested tensors back
        for i, d in enumerate(data):
            d[key] = nested_tensors[key][i]

    return output
Contributor

high

The concat_tensordict function modifies its input data by popping keys. This is a side effect that can lead to unexpected behavior, especially if an error occurs during concatenation, leaving the input TensorDict objects in a corrupted state. The associated test test_concat_tensordict verifies that the inputs are not modified, which indicates that the intended behavior is to be side-effect-free.

To make the function more robust and predictable, it should be refactored to avoid modifying its inputs. A pure function without side effects is easier to reason about and less prone to bugs. I've suggested an alternative implementation that achieves the same result without mutating the input data.

def concat_tensordict(data: list[TensorDict]) -> TensorDict:
    """Concatenates tensordicts into a single tensordict on dim zero. Support nested tensor"""
    assert len(data) > 0, "Must have at least one tensordict"

    # Find nested tensor keys from the first tensordict
    nested_tensor_keys = {
        key for key, value in data[0].items() if isinstance(value, torch.Tensor) and value.is_nested
    }

    if not nested_tensor_keys:
        return TensorDict.cat(data, dim=0)

    # Create a list of tensordicts containing only non-nested tensors for concatenation
    regular_tds = []
    for td in data:
        current_nested_keys = {k for k, v in td.items() if isinstance(v, torch.Tensor) and v.is_nested}
        assert current_nested_keys == nested_tensor_keys, "All tensordicts must have the same set of nested tensors."

        # Create a new TensorDict with non-nested items without modifying the original
        regular_items = {k: v for k, v in td.items() if k not in nested_tensor_keys}
        regular_tds.append(TensorDict(regular_items, batch_size=td.batch_size, device=td.device))

    # Concatenate the regular tensordicts
    output = TensorDict.cat(regular_tds, dim=0)

    # Concatenate and add nested tensors to the output
    for key in nested_tensor_keys:
        nested_tensors_to_concat = [td[key] for td in data]
        output[key] = concat_nested_tensors(nested_tensors_to_concat)

    return output
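The side-effect-free pattern above (detect ragged keys, concatenate regular keys from copies, then re-attach the ragged ones) can be sketched with plain Python containers. Lists stand in for tensors and dicts for TensorDicts; `concat_dicts` is a hypothetical name used only for this sketch:

```python
def concat_dicts(data):
    """Pure-function analogue of the suggested concat_tensordict: the
    inputs are only read, never popped or mutated."""
    # detect ragged ("nested") keys from the first dict: rows of varying length
    ragged_keys = {k for k, v in data[0].items() if isinstance(v[0], list)}

    # analogue of TensorDict.cat over the regular keys, built from reads only
    output = {
        k: [x for d in data for x in d[k]]
        for k in data[0] if k not in ragged_keys
    }

    # analogue of unbind-and-renest over the ragged keys
    for k in ragged_keys:
        output[k] = [row for d in data for row in d[k]]
    return output


d1 = {"ids": [1, 2], "seq": [[7], [8, 9]]}
d2 = {"ids": [3], "seq": [[10, 11]]}
merged = concat_dicts([d1, d2])
assert d1 == {"ids": [1, 2], "seq": [[7], [8, 9]]}  # input left untouched
print(merged["ids"], len(merged["seq"]))  # [1, 2, 3] 3
```

Because the function never pops from its arguments, a failure partway through cannot leave the callers' data in a corrupted state, which is the property the review is asking for.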

@vermouth1992
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for dispatching TensorDict objects, including those with nested tensors, within the single controller framework. The changes are well-structured, touching utility functions, decorators, and adding comprehensive tests. The core logic for handling TensorDict and nested tensor concatenation is implemented in verl/utils/tensordict_utils.py. The modifications to verl/single_controller/base/decorator.py correctly integrate TensorDict into the dispatch mechanism. The new tests in tests/single_controller/test_device_mesh_register.py and tests/test_protocol_v2_on_cpu.py effectively validate the new functionality. I have found one critical issue in the implementation of concat_nested_tensors that will cause it to fail for valid inputs. Please see the detailed comment.

assert tensor.is_nested and tensor.is_contiguous()
unbind_tensors = []
for tensor in tensors:
    assert len(tensor.shape) == 2
Contributor

critical

The assertion len(tensor.shape) == 2 is incorrect for a nested tensor created from a list of 1D tensors. For such a nested tensor, len(tensor.shape) is 1 (representing the batch dimension), while tensor.dim() is 2 (batch dimension + ragged dimension). This assertion will fail for valid 2D nested tensors. It should be changed to tensor.dim() == 2 to correctly check for a 2D nested tensor structure.

Suggested change
- assert len(tensor.shape) == 2
+ assert tensor.dim() == 2

@vermouth1992 vermouth1992 enabled auto-merge (squash) November 20, 2025 16:08
@vermouth1992 vermouth1992 merged commit 7851640 into verl-project:main Nov 21, 2025
84 of 86 checks passed
@vermouth1992 vermouth1992 deleted the chi/dev/dispatch_td branch November 21, 2025 02:54
Di-viner pushed a commit to Di-viner/verl that referenced this pull request Nov 30, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026

2 participants