Commit 1db2672
authored
[single_controller] feat: Support dispatch/collect nested tensors with 3 or more dimensions (verl-project#4940)
### What does this PR do?
There are 2 errors that prevent dispatching and collecting nested
tensors with 3 or more dimensions.
#### Dispatch
When chunking a `TensorDict` with more than 1 nested tensor that has 3
or more dimensions, re-use of the variable name `td` in the function
args and inner loop results in a `KeyError`:
https://github.com/volcengine/verl/blob/e204cd80bd0886c75606d4b82ba88eed2658d1c7/verl/utils/tensordict_utils.py#L276-L312
#### Collection
When collecting returned `TensorDict`s that have a nested tensor with 3
or more dimensions, there is an assertion enforcing that the nested
tensor have exactly 2 dimensions, although the function works for
tensors with an arbitrary number of dimensions:
https://github.com/volcengine/verl/blob/e204cd80bd0886c75606d4b82ba88eed2658d1c7/verl/utils/tensordict_utils.py#L159-L192
### Tests
The added tests demonstrate both errors when run on `main`:
```python
FAILED tests/test_protocol_v2_on_cpu.py::test_concat_nested_tensor - AssertionError: nested tensor must have 2 dimensions. Got torch.Size([2, 4, j32])
FAILED tests/test_protocol_v2_on_cpu.py::test_chunk_tensordict - KeyError: 'key "position_ids" not found in TensorDict with keys [\'attention_mask\', \'input_ids\', \'multi_modal_inputs\']'
```1 parent b7e5074 commit 1db2672
File tree
2 files changed
+57
-8
lines changed- tests
- verl/utils
2 files changed
+57
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
674 | 674 | | |
675 | 675 | | |
676 | 676 | | |
| 677 | + | |
677 | 678 | | |
678 | 679 | | |
679 | 680 | | |
| |||
690 | 691 | | |
691 | 692 | | |
692 | 693 | | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
| 711 | + | |
| 712 | + | |
| 713 | + | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
| 719 | + | |
| 720 | + | |
| 721 | + | |
| 722 | + | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
693 | 730 | | |
694 | 731 | | |
695 | 732 | | |
| |||
755 | 792 | | |
756 | 793 | | |
757 | 794 | | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
| 802 | + | |
| 803 | + | |
758 | 804 | | |
759 | 805 | | |
760 | 806 | | |
| |||
768 | 814 | | |
769 | 815 | | |
770 | 816 | | |
| 817 | + | |
771 | 818 | | |
772 | 819 | | |
773 | 820 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
157 | 157 | | |
158 | 158 | | |
159 | 159 | | |
160 | | - | |
| 160 | + | |
161 | 161 | | |
162 | 162 | | |
163 | | - | |
| 163 | + | |
164 | 164 | | |
165 | 165 | | |
166 | | - | |
167 | | - | |
| 166 | + | |
| 167 | + | |
168 | 168 | | |
169 | 169 | | |
170 | 170 | | |
171 | 171 | | |
172 | 172 | | |
173 | 173 | | |
174 | 174 | | |
175 | | - | |
| 175 | + | |
176 | 176 | | |
177 | 177 | | |
178 | 178 | | |
| |||
184 | 184 | | |
185 | 185 | | |
186 | 186 | | |
187 | | - | |
| 187 | + | |
188 | 188 | | |
189 | 189 | | |
190 | 190 | | |
| |||
306 | 306 | | |
307 | 307 | | |
308 | 308 | | |
309 | | - | |
310 | | - | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
311 | 313 | | |
312 | 314 | | |
313 | 315 | | |
| |||
0 commit comments