66 changes: 41 additions & 25 deletions verl/workers/reward_manager/prime.py
@@ -13,6 +13,8 @@
# limitations under the License.

import asyncio
from collections import defaultdict
from typing import Any
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Callable, Optional
@@ -47,7 +49,7 @@ async def single_compute_score(evaluation_func, completion, reference, task, tas


async def parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64):
    scores = []
    evaluation_entries = []
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        if extra_info is None:
            extra_info = [None] * len(tasks)
@@ -68,12 +70,29 @@ async def parallel_compute_score_async(evaluation_func, completions, references,
    for result, completion, reference, task in zip(results, completions, references, tasks):
        if isinstance(result, Exception) or result is None:
            # Handle failed or timed-out tasks
            scores.append(0.0)
            evaluation_entries.append({"score": 0.0})
        elif isinstance(result[0], (int, float, bool)):
            scores.append(float(result[0]))
            evaluation_entries.append({"score": float(result[0])})
        elif isinstance(result[0], dict):
            assert "score" in result[0], "If the reward function returns a dict, it must contain a 'score' key"
            evaluation_entries.append(result[0])
        else:
            scores.append(float(result[0][0]))
    return scores
            raise ValueError(
                "Reward function must return either a numeric score (int, float, bool) or a dictionary with a 'score' key"
            )

    all_keys = set()
    for entry in evaluation_entries:
        all_keys.update(entry.keys())

    reward_info = defaultdict(list)
    for entry in evaluation_entries:
        for key in all_keys:
            # If another reward function call returned this key and the current
            # entry does not contain it, pad it with the string "unknown".
            reward_info[key].append(entry.get(key, "unknown"))

    return reward_info
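
For reference, a minimal standalone sketch (with hypothetical entries) of the padding behavior above: every key seen in any entry becomes a column in reward_info, and samples whose reward function did not return that key are padded with the string "unknown".

from collections import defaultdict

# Hypothetical per-sample results: only the first call returned the extra "acc" key.
evaluation_entries = [
    {"score": 1.0, "acc": True},
    {"score": 0.0},
]

all_keys = set()
for entry in evaluation_entries:
    all_keys.update(entry.keys())

reward_info = defaultdict(list)
for entry in evaluation_entries:
    for key in all_keys:
        reward_info[key].append(entry.get(key, "unknown"))

print(dict(reward_info))  # {'score': [1.0, 0.0], 'acc': [True, 'unknown']} (key order may vary)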


class PrimeRewardManager:
Expand All @@ -93,7 +112,7 @@ def __init__(
        self.compute_score = compute_score or _default_compute_score
        self.reward_fn_key = reward_fn_key

    def verify(self, data):
    def verify(self, data) -> dict[str, Any]:
        """
        Verify the batch and save the scores as the ``acc`` tensor.
        """
@@ -108,24 +127,21 @@ def verify(self, data):

        assert len(sequences_str) == len(ground_truth) == len(data_sources)
        try:
            scores = asyncio.run(
                parallel_compute_score_async(
                    self.compute_score,
                    sequences_str,
                    ground_truth,
                    data_sources,
                    extra_info=extra_info,
                    num_processes=64,
                )
            )
        except asyncio.TimeoutError:
            print("Global timeout in reward computing! Setting all as 0.")
            scores = [0.0 for _ in range(len(sequences_str))]
            reward_info = asyncio.run(
                parallel_compute_score_async(
                    self.compute_score,
                    sequences_str,
                    ground_truth,
                    data_sources,
                    extra_info=extra_info,
                    num_processes=64,
                )
            )
        except asyncio.TimeoutError:
            print("Global timeout in reward computing! Setting all as 0.")
            reward_info = {"score": [0.0 for _ in range(len(sequences_str))]}
        except Exception as e:
            print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}")
            scores = [0.0 for _ in range(len(sequences_str))]
        data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
        return scores
reward_info["score"] = [0. for _ in range(len(sequences_str))]
data.batch['acc'] = torch.tensor(reward_info["score"], dtype=torch.float32, device=prompt_ids.device)
return reward_info
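
As a usage sketch, a custom compute_score compatible with this contract can return either a bare number or a dict with extra diagnostics next to the mandatory "score" key. The function name and the "acc" key below are hypothetical, and the (data_source, solution_str, ground_truth, extra_info) signature assumes the call convention of verl's default scorer.

def my_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # Hypothetical reward function: exact-match reward plus a diagnostic flag.
    correct = solution_str.strip() == ground_truth.strip()
    return {
        "score": 1.0 if correct else 0.0,  # mandatory key, becomes the "acc" tensor
        "acc": correct,  # extra key, collected into reward_info["acc"]
    }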

    def __call__(self, data: DataProto, return_dict: bool = False):
        """We will expand this function gradually based on the available datasets"""
@@ -147,11 +163,11 @@ def __call__(self, data: DataProto, return_dict: bool = False):
        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        data_sources = data.non_tensor_batch["data_source"]

        scores = self.verify(data)
        reward_info = self.verify(data)

        for i in range(len(data)):
            data_source = data_sources[i]
            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]
            reward_tensor[i, valid_response_length[i].item() - 1] = reward_info["score"][i]

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0
@@ -161,6 +177,6 @@ def __call__(self, data: DataProto, return_dict: bool = False):
                print(sequences_str)

        if return_dict:
            return {"reward_tensor": reward_tensor}
            return {"reward_tensor": reward_tensor, "reward_extra_info": reward_info}
        else:
            return reward_tensor
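
Finally, a toy sketch (hypothetical shapes and values) of how __call__ places each scalar score on the last valid response token of the dense reward tensor:

import torch

batch, resp_len = 2, 5
reward_tensor = torch.zeros(batch, resp_len)
valid_response_length = torch.tensor([3, 5])  # tokens actually generated per sample
scores = [1.0, 0.5]

for i in range(batch):
    # The score lands on the final valid token; all other positions stay 0.
    reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]

print(reward_tensor)  # row 0 has 1.0 at index 2; row 1 has 0.5 at index 4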