Skip to content

Commit 77ef1db

Browse files
[misc] fix: fix list conversion in get_tensordict (#4304)
--- ## What does this PR do? This PR fixes a `ValueError` that occurs when converting `DataProto` containing nested Python structures (lists of lists, lists of dicts, etc.) to `TensorDict`. The issue manifested during distributed training when `non_tensor_batch` fields like `turn_scores`, `reward_extra_info`, `raw_prompt`, and `tool_rewards` contained nested structures that `TensorDict` couldn't handle directly. **Root Cause:** `TensorDict` cannot accept raw nested Python objects like `[[], [0.5, 0.8]]` or `[{"acc": 1.0}, {"acc": 0.0}]`. These must be wrapped using `NonTensorData` and organized into `NonTensorStack` for proper handling. **Solution:** - Explicitly wrap each element in nested lists with `NonTensorData` before creating `NonTensorStack` - Added helper functions `assign_non_tensor_stack()` and `assign_non_tensor()` in `tensordict_utils.py` - Updated `DataProto.to_tensordict()` and `DataProto.from_tensordict()` for proper round-trip conversion - Added automatic nested structure detection in `get_tensordict()` Previous PR: [4296 ](#4296) --- ## Test ### Unit Tests Added **`tests/test_protocol_v2_on_cpu.py`** (8 new tests): - `test_assign_non_tensor_stack_with_nested_lists` - Lists of lists - `test_assign_non_tensor_stack_with_nested_dicts` - Lists of dicts - `test_assign_non_tensor_stack_with_complex_nested` - Lists of lists of dicts - `test_assign_non_tensor_with_auto_detection` - Auto type detection - `test_get_tensordict_with_nested_lists` - Integration with get_tensordict - `test_get_tensordict_with_nested_dicts` - Integration with get_tensordict - `test_get_tensordict_with_complex_nested_structures` - Complex nested case - `test_get_tensordict_agent_loop_scenario` - Real-world agent loop scenario ### How to Run Tests ```bash # Test tensordict_utils nested structure support pytest third_party/open_verl/tests/test_protocol_v2_on_cpu.py -v ``` ### Validation ✅ All new tests pass ✅ Existing tests remain passing ✅ Successfully handles empty lists in nested structures (e.g., `turn_scores = [[], [0.5, 0.8]]`) ✅ Round-trip conversion (DataProto → TensorDict → DataProto) preserves data integrity --- ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 7dd1245 commit 77ef1db

File tree

2 files changed

+286
-15
lines changed

2 files changed

+286
-15
lines changed

tests/test_protocol_v2_on_cpu.py

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import pytest
2424
import torch
25+
from tensordict.tensorclass import NonTensorData, NonTensorStack
2526

2627
from verl.utils import tensordict_utils as tu
2728

@@ -45,10 +46,10 @@ def test_union_tensor_dict():
4546
# conflict in tensor values
4647
tu.union_tensor_dict(data1, data_with_copied_obs)
4748

48-
data1 = tu.assign_non_tensor_dict(data1, meta_info1)
49+
data1 = tu.assign_non_tensor(data1, **meta_info1)
4950
tu.union_tensor_dict(data1, data2) # works ok
5051

51-
data2 = tu.assign_non_tensor_dict(data2, meta_info2)
52+
data2 = tu.assign_non_tensor(data2, **meta_info2)
5253

5354
with pytest.raises(AssertionError):
5455
# conflict in NonTensorData
@@ -651,3 +652,203 @@ def test_concat_tensordict():
651652
# make sure tensordict1 and tensordict2 is untouched
652653
tu.assert_tensordict_eq(tensordict1, tensordict1_copy)
653654
tu.assert_tensordict_eq(tensordict2, tensordict2_copy)
655+
656+
657+
def test_assign_non_tensor_stack_with_nested_lists():
658+
"""Test assign_non_tensor_stack with lists of lists."""
659+
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
660+
661+
# Lists of varying lengths (like turn_scores or tool_rewards)
662+
turn_scores = [[], [0.5, 0.8], [0.9]]
663+
tu.assign_non_tensor_stack(td, "turn_scores", turn_scores)
664+
665+
# Verify data is accessible
666+
assert len(td["turn_scores"]) == 3
667+
assert list(td["turn_scores"][0]) == []
668+
assert list(td["turn_scores"][1]) == [0.5, 0.8]
669+
assert list(td["turn_scores"][2]) == [0.9]
670+
671+
672+
def test_assign_non_tensor_stack_with_nested_dicts():
673+
"""Test assign_non_tensor_stack with lists of dicts."""
674+
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
675+
676+
# Lists of dicts (like reward_extra_info)
677+
reward_extra_info = [{"acc": 1.0, "loss": 0.1}, {"acc": 0.0, "loss": 0.9}, {"acc": 1.0, "loss": 0.05}]
678+
tu.assign_non_tensor_stack(td, "reward_extra_info", reward_extra_info)
679+
680+
# Verify data is accessible
681+
assert len(td["reward_extra_info"]) == 3
682+
assert dict(td["reward_extra_info"][0]) == {"acc": 1.0, "loss": 0.1}
683+
assert dict(td["reward_extra_info"][1]) == {"acc": 0.0, "loss": 0.9}
684+
assert dict(td["reward_extra_info"][2]) == {"acc": 1.0, "loss": 0.05}
685+
686+
687+
def test_assign_non_tensor_stack_with_complex_nested():
688+
"""Test assign_non_tensor_stack with lists of lists of dicts."""
689+
td = tu.get_tensordict({"obs": torch.randn(2, 4)}, non_tensor_dict={})
690+
691+
# Lists of lists of dicts (like raw_prompt)
692+
raw_prompt = [
693+
[{"content": "Question 1", "role": "user"}],
694+
[{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}],
695+
]
696+
tu.assign_non_tensor_stack(td, "raw_prompt", raw_prompt)
697+
698+
# Verify data is accessible
699+
assert len(td["raw_prompt"]) == 2
700+
assert len(td["raw_prompt"][0]) == 1
701+
assert dict(td["raw_prompt"][0][0]) == {"content": "Question 1", "role": "user"}
702+
assert len(td["raw_prompt"][1]) == 2
703+
assert dict(td["raw_prompt"][1][0]) == {"content": "Question 2", "role": "user"}
704+
705+
706+
def test_assign_non_tensor_handles_wrappers():
707+
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
708+
709+
meta = {"top_p": 0.8}
710+
tu.assign_non_tensor(td, **meta)
711+
assert td["top_p"] == 0.8
712+
713+
wrapped = NonTensorData(0.3)
714+
stack = NonTensorStack.from_list([NonTensorData(1.0), NonTensorData(2.0), NonTensorData(3.0)])
715+
tu.assign_non_tensor(td, wrapped=wrapped, stack=stack)
716+
717+
assert td["wrapped"] == 0.3
718+
assert td["stack"] == [1.0, 2.0, 3.0]
719+
720+
721+
def test_assign_non_tensor_stack_batch_size_check():
722+
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
723+
stack = NonTensorStack.from_list([NonTensorData(1.0), NonTensorData(2.0)])
724+
725+
with pytest.raises(RuntimeError):
726+
tu.assign_non_tensor(td, stack=stack)
727+
728+
729+
def test_assign_non_tensor_with_auto_detection():
730+
"""Test assign_non_tensor automatically detects and handles nested structures."""
731+
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
732+
733+
# Mix of simple and nested data
734+
tu.assign_non_tensor(
735+
td,
736+
metadata="experiment_1", # Simple value
737+
turn_scores=[[], [0.5, 0.8], [0.9]], # Nested list
738+
reward_extra_info=[{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}], # List of dicts
739+
simple_list=["a", "b", "c"], # Simple list (also uses NonTensorStack for consistency)
740+
)
741+
742+
# Verify all data is accessible
743+
assert td["metadata"] == "experiment_1"
744+
assert len(td["turn_scores"]) == 3
745+
assert list(td["turn_scores"][1]) == [0.5, 0.8]
746+
assert len(td["reward_extra_info"]) == 3
747+
assert dict(td["reward_extra_info"][0]) == {"acc": 1.0}
748+
assert len(td["simple_list"]) == 3
749+
assert td["simple_list"][0] == "a"
750+
751+
752+
def test_get_tensordict_with_nested_lists():
753+
"""Test get_tensordict automatically handles nested lists."""
754+
obs = torch.randn(3, 4)
755+
turn_scores = [[], [0.5, 0.8], [0.9]]
756+
757+
# This should automatically convert turn_scores to NonTensorStack
758+
td = tu.get_tensordict({"obs": obs, "turn_scores": turn_scores})
759+
760+
# Verify tensors and nested data are both accessible
761+
assert torch.all(torch.eq(td["obs"], obs))
762+
assert len(td["turn_scores"]) == 3
763+
assert list(td["turn_scores"][0]) == []
764+
assert list(td["turn_scores"][1]) == [0.5, 0.8]
765+
766+
767+
def test_get_tensordict_with_nested_dicts():
768+
"""Test get_tensordict automatically handles lists of dicts."""
769+
obs = torch.randn(3, 4)
770+
reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]
771+
772+
td = tu.get_tensordict({"obs": obs, "reward_extra_info": reward_extra_info})
773+
774+
assert torch.all(torch.eq(td["obs"], obs))
775+
assert len(td["reward_extra_info"]) == 3
776+
assert dict(td["reward_extra_info"][0]) == {"acc": 1.0}
777+
778+
779+
def test_get_tensordict_with_complex_nested_structures():
780+
"""Test get_tensordict with lists of lists of dicts."""
781+
obs = torch.randn(2, 4)
782+
raw_prompt = [
783+
[{"content": "Q1", "role": "user"}],
784+
[{"content": "Q2", "role": "user"}, {"content": "A2", "role": "assistant"}],
785+
]
786+
787+
td = tu.get_tensordict({"obs": obs, "raw_prompt": raw_prompt})
788+
789+
assert torch.all(torch.eq(td["obs"], obs))
790+
assert len(td["raw_prompt"]) == 2
791+
assert dict(td["raw_prompt"][0][0]) == {"content": "Q1", "role": "user"}
792+
793+
794+
def test_get_tensordict_agent_loop_scenario():
795+
"""Test the complete agent loop scenario with all nested types.
796+
797+
This simulates the exact use case from agent loops with:
798+
- turn_scores: lists of lists
799+
- reward_extra_info: lists of dicts
800+
- raw_prompt: lists of lists of dicts
801+
- tool_rewards: lists of lists
802+
"""
803+
prompts = torch.randn(2, 10)
804+
responses = torch.randn(2, 5)
805+
806+
# Nested structures from agent loop
807+
data_source = ["lighteval/MATH", "lighteval/MATH"]
808+
uid = ["uuid-1", "uuid-2"]
809+
turn_scores = [[], [0.5, 0.8]] # Lists of varying lengths
810+
reward_extra_info = [{"acc": 1.0, "loss": 0.1}, {"acc": 0.0, "loss": 0.9}]
811+
raw_prompt = [
812+
[{"content": "Compute 4 @ 2", "role": "user"}],
813+
[{"content": "Compute 8 @ 7", "role": "user"}],
814+
]
815+
tool_rewards = [[0.0], []] # List of lists
816+
817+
# This should handle all nested structures automatically
818+
td = tu.get_tensordict(
819+
tensor_dict={
820+
"prompts": prompts,
821+
"responses": responses,
822+
"data_source": data_source,
823+
"uid": uid,
824+
"turn_scores": turn_scores,
825+
"reward_extra_info": reward_extra_info,
826+
"raw_prompt": raw_prompt,
827+
"tool_rewards": tool_rewards,
828+
},
829+
non_tensor_dict={"global_steps": 42},
830+
)
831+
832+
# Verify all data types are accessible
833+
assert torch.all(torch.eq(td["prompts"], prompts))
834+
assert torch.all(torch.eq(td["responses"], responses))
835+
assert td["data_source"] == data_source
836+
assert td["uid"] == uid
837+
838+
# Verify nested structures
839+
assert len(td["turn_scores"]) == 2
840+
assert list(td["turn_scores"][0]) == []
841+
assert list(td["turn_scores"][1]) == [0.5, 0.8]
842+
843+
assert len(td["reward_extra_info"]) == 2
844+
assert dict(td["reward_extra_info"][0]) == {"acc": 1.0, "loss": 0.1}
845+
846+
assert len(td["raw_prompt"]) == 2
847+
assert dict(td["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"}
848+
849+
assert len(td["tool_rewards"]) == 2
850+
assert list(td["tool_rewards"][0]) == [0.0]
851+
assert list(td["tool_rewards"][1]) == []
852+
853+
# Verify metadata
854+
assert td["global_steps"] == 42

verl/utils/tensordict_utils.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,71 @@
2020
from tensordict.tensorclass import NonTensorData, NonTensorStack
2121

2222

23-
def assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):
24-
for key, val in non_tensor_dict.items():
25-
assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)
26-
return tensor_dict
27-
28-
2923
def assign_non_tensor_data(tensor_dict: TensorDict, key, val):
24+
assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
3025
tensor_dict[key] = NonTensorData(val)
3126

3227

33-
def assign_non_tensor(tensordict: TensorDict, **kwargs):
28+
def assign_non_tensor_stack(tensor_dict: TensorDict, key, val: list):
29+
"""Assign a list with potentially nested structures (lists, dicts, etc.) to TensorDict.
30+
31+
This function handles complex nested data structures like:
32+
- Lists of lists: [[], [0.5, 0.8], [0.9]]
33+
- Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
34+
- Lists of lists of dicts: [[{"content": "...", "role": "user"}]]
35+
36+
These structures are wrapped in NonTensorStack so TensorDict can handle them correctly.
37+
38+
Args:
39+
tensor_dict: The TensorDict to assign to
40+
key: The key to assign the value under
41+
val: A list containing potentially nested structures
42+
43+
Example:
44+
>>> td = TensorDict({}, batch_size=[])
45+
>>> turn_scores = [[], [0.5, 0.8], [0.9]]
46+
>>> assign_non_tensor_stack(td, "turn_scores", turn_scores)
47+
>>> # Now td["turn_scores"] contains the nested data
48+
"""
49+
# Convert list to NonTensorStack to handle nested structures
50+
# This wraps each item in NonTensorData to preserve complex objects
51+
# TODO(petersh6): can convert back to val directly if we are not accessing .data from the NonTensorStack
52+
assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
53+
tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
54+
55+
56+
def assign_non_tensor(tensor_dict: TensorDict, **kwargs):
57+
"""Assign non-tensor data to a TensorDict.
58+
59+
Automatically detects if the value is a list with nested structures and uses
60+
the appropriate assignment method (NonTensorData for simple values,
61+
NonTensorStack for lists with nested structures).
62+
63+
Args:
64+
tensor_dict: The TensorDict to assign to
65+
**kwargs: Key-value pairs where values can be:
66+
- Simple values (stored as NonTensorData)
67+
- Lists with nested structures (stored as NonTensorStack)
68+
69+
Example:
70+
>>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3])
71+
>>> assign_non_tensor(
72+
... tensor_dict=td,
73+
... metadata="experiment_1", # Simple value
74+
... turn_scores=[[], [0.5, 0.8], [0.9]] # Nested list
75+
... )
76+
"""
77+
assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict"
3478
for key, val in kwargs.items():
35-
assign_non_tensor_data(tensor_dict=tensordict, key=key, val=val)
36-
return tensordict
79+
if isinstance(val, (NonTensorData | NonTensorStack)):
80+
tensor_dict[key] = val
81+
elif isinstance(val, list):
82+
# For lists, use NonTensorStack
83+
assign_non_tensor_stack(tensor_dict=tensor_dict, key=key, val=val)
84+
else:
85+
# For non-list values, use NonTensorData
86+
assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)
87+
return tensor_dict
3788

3889

3990
def unwrap_non_tensor_data(data):
@@ -92,15 +143,31 @@ def concat_tensordict(data: list[TensorDict]) -> TensorDict:
92143

93144

94145
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
95-
"""
146+
"""Create a TensorDict from tensors and non-tensor data.
147+
148+
Automatically handles nested structures in lists by converting them to NonTensorStack.
149+
This enables support for:
150+
- Lists of lists: [[], [0.5, 0.8], [0.9]]
151+
- Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
152+
- Lists of lists of dicts: [[{"content": "...", "role": "user"}]]
96153
97154
Args:
98-
data_dict:
99-
meta_info:
155+
tensor_dict: Dictionary of tensors and lists to include in the TensorDict
156+
non_tensor_dict: Dictionary of metadata to store as NonTensorData
100157
101158
Returns:
102-
159+
TensorDict with proper handling of nested structures
160+
161+
Example:
162+
>>> td = get_tensordict(
163+
... tensor_dict={
164+
... "obs": torch.randn(3, 4),
165+
... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list
166+
... },
167+
... non_tensor_dict={"experiment": "test"}
168+
... )
103169
"""
170+
tensor_dict = tensor_dict.copy()
104171
if non_tensor_dict is None:
105172
non_tensor_dict = {}
106173

@@ -127,6 +194,9 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
127194
"Passing a list makes the data NonTensorStack, "
128195
"which doesn't support torch.Tensor. Please convert to numpy first"
129196
)
197+
# Convert to NonTensorStack to handle nested structures
198+
tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
199+
130200
assert isinstance(val, torch.Tensor | list)
131201

132202
if batch_size is None:

0 commit comments

Comments
 (0)