18 changes: 17 additions & 1 deletion tests/experimental/agent_loop/test_basic_agent_loop.py
@@ -26,6 +26,7 @@
from verl.protocol import DataProto
from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
from verl.tools.schemas import ToolResponse
from verl.trainer.ppo.reward import compute_reward, load_reward_manager
from verl.utils import hf_tokenizer


@@ -41,6 +42,10 @@ def init_config() -> DictConfig:
# test sleep/wake_up with fsdp offload
"actor_rollout_ref.actor.fsdp_config.param_offload=True",
"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
"reward_model.reward_manager=dapo",
"+reward_model.reward_kwargs.overlong_buffer_cfg.enable=False",
"+reward_model.reward_kwargs.overlong_buffer_cfg.len=3072",
"+reward_model.reward_kwargs.max_resp_len=4096",
],
)
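These overrides exercise the DAPO reward manager; once resolved by Hydra, they amount to roughly the following keyword arguments for load_reward_manager (a sketch of the resolved values, not the actual config object):

# illustrative only: what reward_model.reward_kwargs resolves to under the overrides above
reward_kwargs = {
    "overlong_buffer_cfg": {"enable": False, "len": 3072},
    "max_resp_len": 4096,
}
# later passed through as: load_reward_manager(config, tokenizer, num_examine=0, **reward_kwargs)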

@@ -69,6 +74,10 @@ def test_single_turn(init_config):
)

agent_loop_manager = init_agent_loop_manager(init_config)
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
reward_fn = load_reward_manager(
init_config, tokenizer, num_examine=0, **init_config.reward_model.get("reward_kwargs", {})
)

raw_prompts = [
[
@@ -97,10 +106,17 @@
assert result.batch["input_ids"].size(1) == seq_len
assert result.batch["attention_mask"].size(1) == seq_len
assert result.batch["position_ids"].size(1) == seq_len
assert result.batch["rm_scores"].size(1) == result.batch["responses"].size(1)

if init_config.actor_rollout_ref.rollout.calculate_log_probs:
assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1)

# check compute score
assert result.batch["rm_scores"].shape == result.batch["responses"].shape
reward_tensor, reward_extra_info = compute_reward(result, reward_fn)
assert reward_tensor.shape == result.batch["responses"].shape
assert "acc" in reward_extra_info, f"reward_extra_info {reward_extra_info} should contain 'acc'"
assert reward_extra_info["acc"].shape == (len(result),), f"invalid acc: {reward_extra_info['acc']}"

# check turns
num_turns = result.non_tensor_batch["__num_turns__"]
assert np.all(num_turns == 2)
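For context, a minimal sketch of the contract the new assertions check, with an illustrative batch size and response length (not taken from the test config):

import numpy as np
import torch

batch_size, resp_len = 2, 4
# compute_reward(result, reward_fn) is expected to return a token-level reward tensor
# shaped like result.batch["responses"] plus per-sample extra-info arrays such as "acc".
reward_tensor = torch.zeros(batch_size, resp_len)
reward_extra_info = {"acc": np.array([1.0, 0.0])}

assert reward_tensor.shape == (batch_size, resp_len)
assert reward_extra_info["acc"].shape == (batch_size,)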
35 changes: 27 additions & 8 deletions verl/experimental/agent_loop/agent_loop.py
@@ -134,6 +134,8 @@ class AgentLoopOutput(BaseModel):
"""Number of chat turns, including user, assistant, tool."""
metrics: AgentLoopMetrics
"""Auxiliary performance metrics"""
extra_fields: dict[str, Any] = {}
"""Extra fields for dynamic addition."""


class _InternalAgentLoopOutput(AgentLoopOutput):
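A small sketch of how the new field is meant to be used, on a stand-in model (constructing a full AgentLoopOutput needs more fields); the "reward_extra_info" key is the one this PR attaches later:

from typing import Any
from pydantic import BaseModel

class Output(BaseModel):
    # stand-in for AgentLoopOutput's new field; pydantic copies the default per instance
    extra_fields: dict[str, Any] = {}

out = Output()
out.extra_fields["reward_extra_info"] = {"acc": 1.0}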
@@ -252,7 +254,7 @@ def __init__(self, config: DictConfig, local_path: str) -> None:
)
self.loop = asyncio.get_event_loop()

async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> float:
async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> dict:
"""Compute reward score for agent loop output.

NOTE: Since `reward_manager.__call__` is blocking function, we run it in thread pool to
@@ -263,7 +265,7 @@ async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> float:
kwargs (dict): Dataset fields from `verl.utils.dataset.RLHFDataset`.

Returns:
float: Reward score.
dict: Reward score and reward extra info.
"""
prompts = torch.tensor(output.prompt_ids, dtype=torch.long).unsqueeze(0)
responses = torch.tensor(output.response_ids, dtype=torch.long).unsqueeze(0)
@@ -284,12 +286,16 @@
batch=batch,
non_tensor_batch=non_tensor_batch,
)
reward_tensor = await self.loop.run_in_executor(
result = await self.loop.run_in_executor(
None,
self.reward_manager,
data,
True, # return_dict
)
return reward_tensor.sum(dim=-1).item()

reward_score = result["reward_tensor"].sum(dim=-1).item()
reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}


@ray.remote
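A minimal sketch of the new return contract, assuming the reward manager's return_dict result carries one-element sequences for the single-sample DataProto built above (the "acc" key is illustrative):

import torch

# what self.reward_manager(data, True) might return for a batch of one sample
result = {
    "reward_tensor": torch.tensor([[0.0, 0.0, 1.0]]),  # (1, response_len)
    "reward_extra_info": {"acc": [1.0]},                # one entry per sample
}

reward_score = result["reward_tensor"].sum(dim=-1).item()                      # 1.0
reward_extra_info = {k: v[0] for k, v in result["reward_extra_info"].items()}  # {"acc": 1.0}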
@@ -424,7 +430,9 @@ async def _run_agent_loop(

# Some AgentLoop may have already computed the reward score, e.g SWE-agent.
if output.reward_score is None and not self.config.reward_model.enable:
output.reward_score = await self.reward_manager_worker.compute_score.remote(output, kwargs)
result = await self.reward_manager_worker.compute_score.remote(output, kwargs)
output.reward_score = result["reward_score"]
output.extra_fields["reward_extra_info"] = result["reward_extra_info"]

# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
@@ -534,6 +542,7 @@ async def _run_agent_loop(
reward_score=output.reward_score,
num_turns=output.num_turns,
metrics=output.metrics,
extra_fields=output.extra_fields,
)

def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
@@ -575,13 +584,23 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
"__num_turns__": np.array([input.num_turns for input in inputs], dtype=np.int32),
}

# add reward_extra_info to non_tensor_batch
reward_extra_infos = [input.extra_fields.get("reward_extra_info", {}) for input in inputs]
reward_extra_keys = list(reward_extra_infos[0].keys())
for key in reward_extra_keys:
non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos])

# Add multi_modal_inputs to non_tensor_batch if any samples have them
multi_modal_inputs_list = [input.multi_modal_inputs for input in inputs]
if any(mmi is not None for mmi in multi_modal_inputs_list):
non_tensor_batch["multi_modal_inputs"] = np.array(multi_modal_inputs_list, dtype=object)

metrics = [input.metrics.model_dump() for input in inputs]
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info={"metrics": metrics})
return DataProto(
batch=batch,
non_tensor_batch=non_tensor_batch,
meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
)


async def get_trajectory_info(step, index, validate):
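To illustrate the postprocessing above: assuming two samples whose agent loops attached an "acc" entry (key and values are illustrative), the per-sample dicts become columns of non_tensor_batch and the key list is carried in meta_info:

import numpy as np

reward_extra_infos = [{"acc": 1.0}, {"acc": 0.0}]       # one dict per sample
reward_extra_keys = list(reward_extra_infos[0].keys())  # ["acc"]

non_tensor_batch = {
    key: np.array([info[key] for info in reward_extra_infos]) for key in reward_extra_keys
}
# non_tensor_batch == {"acc": array([1., 0.])}
# meta_info={"reward_extra_keys": ["acc"]} tells downstream reward managers which
# non-tensor columns to surface as reward_extra_info.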
@@ -717,10 +736,10 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
self.sleep()

# calculate performance metrics
metrics = [output.meta_info["metrics"] for output in outputs] # List[List[Dict[str, str]]]
metrics = [output.meta_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]]
timing = self._performance_metrics(metrics, output)

output.meta_info = {"timing": timing}
output.meta_info = {"timing": timing, **outputs[0].meta_info}
return output

def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
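A rough sketch of the meta_info handling above: popping "metrics" consumes it for timing, and whatever else each chunk carried (here, the reward_extra_keys set in _postprocess) survives into the merged output (values are illustrative):

chunk_meta = {"metrics": [{"generate_sequences": "0.8"}], "reward_extra_keys": ["acc"]}

metrics = chunk_meta.pop("metrics")                      # consumed for performance/timing stats
final_meta_info = {"timing": {"agent_loop": 1.2}, **chunk_meta}
# final_meta_info == {"timing": {"agent_loop": 1.2}, "reward_extra_keys": ["acc"]}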
4 changes: 3 additions & 1 deletion verl/workers/reward_manager/batch.py
@@ -77,7 +77,9 @@ def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor |
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]

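The same early-return change is repeated for the dapo, naive, and prime managers below. A sketch of what a caller now receives when rm_scores were already computed in the agent loop (key names are illustrative):

import numpy as np
import torch

# stand-ins for the relevant pieces of a DataProto whose rm_scores were precomputed
rm_scores = torch.zeros(2, 4)                  # (batch_size, response_len)
meta_info = {"reward_extra_keys": ["acc"]}
non_tensor_batch = {"acc": np.array([1.0, 0.0])}

# mirrors the return_dict branch: pass the precomputed scores through and rebuild
# reward_extra_info from the non-tensor columns named in meta_info
reward_extra_keys = meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: non_tensor_batch[key] for key in reward_extra_keys}
result = {"reward_tensor": rm_scores, "reward_extra_info": reward_extra_info}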
4 changes: 3 additions & 1 deletion verl/workers/reward_manager/dapo.py
@@ -56,7 +56,9 @@ def __call__(self, data: DataProto, return_dict: bool = False):
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]

4 changes: 3 additions & 1 deletion verl/workers/reward_manager/naive.py
@@ -49,7 +49,9 @@ def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor |
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
return {"reward_tensor": data.batch["rm_scores"]}
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]

7 changes: 6 additions & 1 deletion verl/workers/reward_manager/prime.py
@@ -153,7 +153,12 @@ def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor |

# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
return data.batch["rm_scores"]
if return_dict:
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]

reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
