1 change: 1 addition & 0 deletions recipe/dapo/config/GRM_template.txt
@@ -0,0 +1 @@
The above is a Q&A dialogue between a user and an assistant. It is now known that the standard answer to the user's question is {ground_truth}. Please determine whether the assistant has answered the user's question clearly, precisely, and accurately. Present your reasoning and judgment in the following format:\n\nThink: Content of Thinking\nJudgment: Correct / Incorrect
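
The template instructs the judge model to end its completion with a `Judgment: Correct / Incorrect` verdict, which a rule-based reward function can map back to a scalar. A minimal sketch of such a parser, assuming the template above; the function name and the 1.0/0.0 convention are illustrative, not part of this PR:

import re

def parse_grm_judgment(grm_text: str) -> float:
    """Map a GRM completion following the template to a binary reward.

    Expects output of the form:
        Think: <reasoning>
        Judgment: Correct / Incorrect
    Returns 1.0 for "Correct" and 0.0 otherwise (illustrative convention).
    """
    match = re.search(r"Judgment:\s*(Correct|Incorrect)", grm_text, re.IGNORECASE)
    return 1.0 if match and match.group(1).lower() == "correct" else 0.0
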
247 changes: 246 additions & 1 deletion recipe/dapo/dapo_ray_trainer.py
@@ -16,6 +16,7 @@
This trainer supports model-agnostic model initialization with huggingface
"""

import random
import uuid
from collections import defaultdict
from copy import deepcopy
@@ -41,6 +42,175 @@ class RayDAPOTrainer(RayPPOTrainer):
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""

def _validate(self):
"""Override the parent validation method to add GRM support"""
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto

data_source_lst = []
reward_extra_infos_dict: dict[str, list] = defaultdict(list)

# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_scores = []

# Collect all test batches for logging
all_test_batches = None

for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)

# repeat test batch
test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True)

# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
return {}

# Store original inputs
input_ids = test_batch.batch["input_ids"]
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)

batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_inputs" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
if "raw_prompt" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
test_gen_batch = test_batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)

test_gen_batch.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"recompute_log_prob": False,
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
"validate": True,
}
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)

# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print("validation generation end")

# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)

test_batch = test_batch.union(test_output_gen_batch)

# Add GRM support similar to training phase
if self.use_grm:
# pad to be divisible by grm worker group size before GRM processing
test_batch_grm_padded, grm_pad_size = pad_dataproto_to_divisor(test_batch, self.grm_wg.world_size)
grm_output_padded = self.grm_wg.generate_sequences_as_grm(test_batch_grm_padded)

# unpad GRM output
grm_output = unpad_dataproto(grm_output_padded, pad_size=grm_pad_size)

# Add "_grm" suffix to all keys from output and union into test_batch
grm_reward_tensor = {}
if hasattr(grm_output, "batch") and grm_output.batch is not None:
for key, value in grm_output.batch.items():
grm_reward_tensor[key + "_grm"] = value

# Create DataProto object for union
if grm_reward_tensor:
grm_reward_data_proto = DataProto.from_dict(grm_reward_tensor)
test_batch = test_batch.union(grm_reward_data_proto)
else:
print("Validation: No grm_reward_tensor from GRM output")

# evaluate using reward_function
result = self.val_reward_fn(test_batch, return_dict=True)
reward_tensor = result["reward_tensor"]
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)

reward_extra_infos_dict["reward"].extend(scores)
if reward_extra_infos_dict:
test_batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)

# Add reward scores to non_tensor_batch for logging
current_scores = reward_tensor.sum(-1).cpu().tolist()
test_batch.non_tensor_batch["reward"] = np.array(current_scores)
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
if key not in test_batch.non_tensor_batch:
test_batch.non_tensor_batch[key] = np.array(lst)

# Ensure test_batch has prompts field for logging
if "prompts" not in test_batch.batch:
# Reconstruct prompts from the original input_ids we saved
original_input_ids = test_gen_batch.batch["input_ids"]
test_batch.batch["prompts"] = original_input_ids

# Collect test_batch for logging
if all_test_batches is None:
all_test_batches = test_batch
else:
all_test_batches = DataProto.concat([all_test_batches, test_batch])

data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

# Log generations from all collected batches
if all_test_batches is not None:
self._maybe_log_val_generations(all_test_batches)

# dump generations
val_data_dir = self.config.trainer.get("validation_data_dir", None)
if val_data_dir:
self._dump_generations(
inputs=sample_inputs,
outputs=sample_outputs,
scores=sample_scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=val_data_dir,
)

for key_info, lst in reward_extra_infos_dict.items():
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

data_sources = np.concatenate(data_source_lst, axis=0)

from verl.trainer.ppo.metric_utils import process_validation_metrics

data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
core_var = "acc" if "acc" in var2metric2val else "reward"
for var_name, metric2val in var2metric2val.items():
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
for metric_name, metric_val in metric2val.items():
if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name):
metric_sec = "val-core"
else:
metric_sec = "val-aux"
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
metric_dict[pfx] = metric_val

# validation response length statistics (character count of the decoded text)
response_lens = [len(text) for text in sample_outputs]
metric_dict["val/response_len/mean"] = np.mean(response_lens)
metric_dict["val/response_len/max"] = np.max(response_lens)
metric_dict["val/response_len/min"] = np.min(response_lens)

return metric_dict
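
# For reference, the loop above emits flat metric keys shaped like
# "{metric_sec}/{data_source}/{var_name}/{metric_name}". With an illustrative
# data source and n_max = 32 (placeholder names, not fixed by this PR):
#   val-core/openai_gsm8k/acc/mean@32    <- core variable at the largest n
#   val-aux/openai_gsm8k/acc/mean@1      <- same variable at a smaller n
#   val-aux/openai_gsm8k/reward/mean@32  <- non-core variable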

def fit(self):
"""
The training loop of PPO.
@@ -137,11 +307,26 @@ def fit(self):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm and not self.use_grm:  # was: `if self.use_rm:`
    # we first compute the reward model score
    reward_tensor = self.rm_wg.compute_rm_score(new_batch)
    new_batch = new_batch.union(reward_tensor)

if self.use_grm:
    grm_output = self.grm_wg.generate_sequences_as_grm(new_batch)
    # Suffix every key from the GRM output with "_grm" so the judge's rollout
    # does not collide with the policy rollout keys, then union into new_batch
    grm_tensors = {}
    if hasattr(grm_output, "batch") and grm_output.batch is not None:
        for key, value in grm_output.batch.items():
            grm_tensors[key + "_grm"] = value

    # Create DataProto object for union
    if grm_tensors:
        grm_data_proto = DataProto.from_dict(grm_tensors)
        new_batch = new_batch.union(grm_data_proto)
    else:
        print("No GRM tensors found in GRM output")

# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
@@ -312,3 +497,63 @@ def fit(self)

progress_bar.update(1)
self.global_steps += 1

def _maybe_log_val_generations(self, batch: DataProto):
"""Log a table of validation samples to the configured logger (wandb or swanlab)"""
generations_to_log = self.config.trainer.get("log_val_generations", 0)

if generations_to_log == 0:
return

prompts = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
responses = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)

# Decode the GRM rollout only when both of its fields are present
has_grm = (
    batch.batch.get("prompts_grm", None) is not None
    and batch.batch.get("responses_grm", None) is not None
)
if has_grm:
    prompts_grm = self.tokenizer.batch_decode(batch.batch["prompts_grm"], skip_special_tokens=True)
    responses_grm = self.tokenizer.batch_decode(batch.batch["responses_grm"], skip_special_tokens=True)

res_ids = list(range(len(prompts)))
sample_ids = random.sample(res_ids, min(generations_to_log, len(res_ids)))

# Create samples as dict[dict] format, one entry per sampled generation
samples = {}
for i, idx in enumerate(sample_ids):
    # Use reward score from non_tensor_batch, fall back to acc if available
    if "reward" in batch.non_tensor_batch:
        score_val = f"{batch.non_tensor_batch['reward'][idx]:.2f}"
    elif "acc" in batch.non_tensor_batch:
        score_val = f"{batch.non_tensor_batch['acc'][idx]:.2f}"
    else:
        score_val = "N/A"

    sample_data = {"input": prompts[idx], "output": responses[idx], "score": score_val}
    if has_grm:
        sample_data["input_grm"] = prompts_grm[idx]
        sample_data["output_grm"] = responses_grm[idx]
    samples[f"sample_{i + 1}"] = sample_data

# Log to each configured logger
self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
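
After the "_grm" union performed in this trainer, every sample carries the judge's rollout (prompts_grm, responses_grm) alongside the policy rollout, so a custom reward function can decode and score the judgments. A rough sketch, reusing the parse_grm_judgment sketch from the template section; the function name and the last-token reward placement are assumptions, not this PR's actual reward_fn:

import torch

def compute_grm_reward_tensor(batch, tokenizer):
    # "responses_grm" holds the judge's token ids, added by the "_grm" union
    judge_texts = tokenizer.batch_decode(batch.batch["responses_grm"], skip_special_tokens=True)

    # One scalar per sample, placed on the last response token (a common
    # convention for token-level reward tensors; an assumption here)
    reward_tensor = torch.zeros_like(batch.batch["responses"], dtype=torch.float32)
    for i, text in enumerate(judge_texts):
        reward_tensor[i, -1] = parse_grm_judgment(text)
    return reward_tensor
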
7 changes: 6 additions & 1 deletion recipe/dapo/main_dapo.py
@@ -15,7 +15,6 @@
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other main entry points.
"""


import hydra
import ray

@@ -89,8 +88,10 @@ def run(self, config):
}

global_pool_id = "global_pool"
grm_pool_id = "grm_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
grm_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.get("grm_nnodes", 0),
}
mapping = {
Role.ActorRollout: global_pool_id,
@@ -113,6 +114,10 @@ def run(self, config):
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id

if config.reward_model.grm.enable:
role_worker_mapping[Role.GenerativeRewardModel] = ray.remote(ActorRolloutRefWorker)
mapping[Role.GenerativeRewardModel] = grm_pool_id

# reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
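
Taken together, the new keys read in this file suggest that a GRM run is switched on via reward_model.grm.enable and given dedicated GPUs via trainer.grm_nnodes (defaulting to 0, i.e. an empty grm_pool). An illustrative Hydra-style invocation; the override values are placeholders, not tested defaults:

python3 -m recipe.dapo.main_dapo \
    reward_model.grm.enable=True \
    trainer.grm_nnodes=1 \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=8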