Skip to content

Commit 1db2672

Browse files
authored
[single_controller] feat: Support dispatch/collect nested tensors with 3 or more dimensions (verl-project#4940)
### What does this PR do? There are 2 errors that prevent dispatching and collecting nested tensors with 3 or more dimensions. #### Dispatch When chunking a `TensorDict` that contains more than 1 nested tensor with 3 or more dimensions, reusing the variable name `td` for both the function argument and the inner-loop variable results in a `KeyError`: https://github.com/volcengine/verl/blob/e204cd80bd0886c75606d4b82ba88eed2658d1c7/verl/utils/tensordict_utils.py#L276-L312 #### Collection When collecting returned `TensorDict`s that contain a nested tensor with 3 or more dimensions, an assertion requires that each nested tensor has exactly 2 dimensions, even though the function works for tensors with an arbitrary number of dimensions: https://github.com/volcengine/verl/blob/e204cd80bd0886c75606d4b82ba88eed2658d1c7/verl/utils/tensordict_utils.py#L159-L192 ### Tests The added tests demonstrate both errors when run on `main`: ```python FAILED tests/test_protocol_v2_on_cpu.py::test_concat_nested_tensor - AssertionError: nested tensor must have 2 dimensions. Got torch.Size([2, 4, j32]) FAILED tests/test_protocol_v2_on_cpu.py::test_chunk_tensordict - KeyError: 'key "position_ids" not found in TensorDict with keys [\'attention_mask\', \'input_ids\', \'multi_modal_inputs\']'
1 parent b7e5074 commit 1db2672

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

tests/test_protocol_v2_on_cpu.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def test_dataproto_chunk_after_index():
674674

675675

676676
def test_concat_nested_tensor():
677+
# Test 2D nested tensors
677678
vocab_size = 128
678679
a = torch.randint(low=0, high=vocab_size, size=(11,))
679680
b = torch.randint(low=0, high=vocab_size, size=(13,))
@@ -690,6 +691,42 @@ def test_concat_nested_tensor():
690691

691692
assert torch.all(torch.eq(output_values, expected)).item()
692693

694+
# Test 3D nested tensors
695+
a_3d = torch.randint(low=0, high=vocab_size, size=(4, 4))
696+
b_3d = torch.randint(low=0, high=vocab_size, size=(4, 5))
697+
c_3d = torch.randint(low=0, high=vocab_size, size=(4, 6))
698+
d_3d = torch.randint(low=0, high=vocab_size, size=(4, 7))
699+
700+
nested_a_b_3d = torch.nested.as_nested_tensor([a_3d, b_3d], layout=torch.jagged)
701+
nested_c_d_3d = torch.nested.as_nested_tensor([c_3d, d_3d], layout=torch.jagged)
702+
703+
output_3d = tu.concat_nested_tensors([nested_a_b_3d, nested_c_d_3d])
704+
705+
assert output_3d.shape[0] == 4
706+
output_3d_unbind = output_3d.unbind(0)
707+
assert torch.all(torch.eq(output_3d_unbind[0], a_3d)).item()
708+
assert torch.all(torch.eq(output_3d_unbind[1], b_3d)).item()
709+
assert torch.all(torch.eq(output_3d_unbind[2], c_3d)).item()
710+
assert torch.all(torch.eq(output_3d_unbind[3], d_3d)).item()
711+
712+
# Test 4D nested tensors
713+
a_4d = torch.randint(low=0, high=vocab_size, size=(2, 3, 4))
714+
b_4d = torch.randint(low=0, high=vocab_size, size=(2, 3, 5))
715+
c_4d = torch.randint(low=0, high=vocab_size, size=(2, 3, 3))
716+
d_4d = torch.randint(low=0, high=vocab_size, size=(2, 3, 6))
717+
718+
nested_a_b_4d = torch.nested.as_nested_tensor([a_4d, b_4d], layout=torch.jagged)
719+
nested_c_d_4d = torch.nested.as_nested_tensor([c_4d, d_4d], layout=torch.jagged)
720+
721+
output_4d = tu.concat_nested_tensors([nested_a_b_4d, nested_c_d_4d])
722+
723+
assert output_4d.shape[0] == 4
724+
output_4d_unbind = output_4d.unbind(0)
725+
assert torch.all(torch.eq(output_4d_unbind[0], a_4d)).item()
726+
assert torch.all(torch.eq(output_4d_unbind[1], b_4d)).item()
727+
assert torch.all(torch.eq(output_4d_unbind[2], c_4d)).item()
728+
assert torch.all(torch.eq(output_4d_unbind[3], d_4d)).item()
729+
693730

694731
def test_concat_tensordict():
695732
vocab_size = 128
@@ -755,6 +792,15 @@ def test_chunk_tensordict():
755792
input_ids = torch.nested.as_nested_tensor(
756793
[torch.arange(4), torch.arange(5), torch.arange(6), torch.arange(7)], layout=torch.jagged
757794
)
795+
attention_mask = torch.nested.as_nested_tensor(
796+
[
797+
torch.randint(low=0, high=2, size=[3, 4]),
798+
torch.randint(low=0, high=2, size=[3, 5]),
799+
torch.randint(low=0, high=2, size=[3, 6]),
800+
torch.randint(low=0, high=2, size=[3, 7]),
801+
],
802+
layout=torch.jagged,
803+
)
758804

759805
multi_modal_inputs = torch.stack(
760806
[
@@ -768,6 +814,7 @@ def test_chunk_tensordict():
768814
{
769815
"input_ids": input_ids,
770816
"position_ids": position_ids,
817+
"attention_mask": attention_mask,
771818
"multi_modal_inputs": multi_modal_inputs,
772819
},
773820
)

verl/utils/tensordict_utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,22 @@ def get_non_tensor_data(data: TensorDict, key: str, default):
157157

158158

159159
def concat_nested_tensors(tensors: list[torch.Tensor]) -> torch.Tensor:
160-
"""Concatenate multiple 2D nested tensors along the batch dimension.
160+
"""Concatenate multiple nested tensors along the batch dimension.
161161
162162
Takes a list of nested tensors with jagged layout and concatenates them
163-
into a single nested tensor. Each input tensor must be 2D and contiguous.
163+
into a single nested tensor. Each input tensor must have 2 or more dimensions and be contiguous.
164164
165165
Args:
166-
tensors: List of 2D nested tensors to concatenate. All tensors must
167-
be nested, contiguous, and have exactly 2 dimensions.
166+
tensors: List of nested tensors to concatenate. All tensors must
167+
be nested, contiguous, and have 2 or more dimensions.
168168
169169
Returns:
170170
A new nested tensor with jagged layout containing all rows from
171171
the input tensors concatenated along dimension 0.
172172
173173
Raises:
174174
AssertionError: If any tensor is not nested, not contiguous, or
175-
doesn't have exactly 2 dimensions.
175+
doesn't have 2 or more dimensions.
176176
177177
Example:
178178
>>> t1 = torch.nested.as_nested_tensor([torch.randn(3), torch.randn(5)], layout=torch.jagged)
@@ -184,7 +184,7 @@ def concat_nested_tensors(tensors: list[torch.Tensor]) -> torch.Tensor:
184184
assert tensor.is_nested and tensor.is_contiguous()
185185
unbind_tensors = []
186186
for tensor in tensors:
187-
assert len(tensor.shape) == 2, f"nested tensor must have 2 dimensions. Got {tensor.shape}"
187+
assert len(tensor.shape) >= 2, f"nested tensor must have 2 or more dimensions. Got {tensor.shape}"
188188
unbind_tensor = tensor.unbind(0)
189189
unbind_tensors.extend(list(unbind_tensor))
190190

@@ -306,8 +306,10 @@ def chunk_tensordict(td: TensorDict, chunks: int) -> list[TensorDict]:
306306
tds = new_td.chunk(chunks=chunks)
307307
for key in keys:
308308
tensors = td[key].unbind(dim=0)
309-
for i, td in enumerate(tds):
310-
td[key] = torch.nested.as_nested_tensor(tensors[i * chunk_size : (i + 1) * chunk_size], layout=torch.jagged)
309+
for i, chunk_td in enumerate(tds):
310+
chunk_td[key] = torch.nested.as_nested_tensor(
311+
tensors[i * chunk_size : (i + 1) * chunk_size], layout=torch.jagged
312+
)
311313

312314
return tds
313315

0 commit comments

Comments
 (0)