Conversation
Code Review
This pull request aims to fix a bug by modifying the prompt filtering logic in `dapo_ray_trainer.py`. The change adds a condition to retain prompt groups where all generated responses have the same, non-extremal reward. However, my review indicates that this change may introduce a critical issue by providing incorrect training targets to the value function (critic) when using the GRPO advantage estimator. This could lead to training instability and an incorrectly trained model. I've provided a detailed comment on this potential issue.
```diff
 kept_prompt_uids = [
     uid
     for uid, std in prompt_uid2metric_std.items()
-    if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
+    if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 or prompt_uid2within[uid]
 ]
```
The new condition `or prompt_uid2within[uid]` keeps prompt groups where all responses have the same non-extremal reward (a constant `c` where `0 < c < 1`). While this might seem to add more data to the batch, it appears to introduce incorrect training signals for the critic when using the `grpo` advantage estimator.
Here's the breakdown of the issue:
- For these newly kept prompt groups, the standard deviation of rewards is 0.
- When using the `grpo` advantage estimator, `compute_grpo_outcome_advantage` will calculate an advantage of 0 for all responses in such a group.
- In `dapo_ray_trainer.py`, the `returns` for the critic are set equal to the `advantages`.
- Therefore, for these prompts, the critic will be trained with a target (`returns`) of 0.
- However, the value function should predict the expected reward. Since all responses for the prompt have a reward of `c`, the expected reward is `c`, not 0.
This mismatch trains the value function towards an incorrect target, which can destabilize the training process and harm model performance.
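
To make the zero-advantage behavior concrete, here is a minimal sketch of the per-group normalization that GRPO-style outcome-advantage estimators apply. This is a simplification for illustration, not verl's actual `compute_grpo_outcome_advantage`:

```python
import torch

def grpo_outcome_advantage_sketch(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Group-normalized outcome advantage for one prompt group.

    rewards: shape (num_responses,), the scalar outcome reward of each
    response sampled for the same prompt.
    """
    mean = rewards.mean()
    std = rewards.std()
    # When every response earns the same reward c, rewards - mean == 0,
    # so every advantage in the group is exactly 0, regardless of c.
    return (rewards - mean) / (std + eps)

group = torch.tensor([0.5, 0.5, 0.5, 0.5])  # constant reward c = 0.5
print(grpo_outcome_advantage_sketch(group))  # tensor([0., 0., 0., 0.])
# With returns set equal to advantages, the critic's target for this
# prompt becomes 0, even though the expected reward is c = 0.5.
```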
If the intention is to handle cases where rewards are continuous, it might be better to adjust the filtering threshold (e.g., `std > 1e-6`) rather than adding this condition, which seems to break the critic's training logic. Could you clarify the reasoning behind this change?
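
For illustration, the suggested alternative would look roughly like the following. This is only a sketch: `EPS` is a hypothetical tolerance, and `prompt_uid2metric_std` / `prompt_uid2metric_vals` mirror the names in the diff above rather than a standalone program:

```python
EPS = 1e-6  # hypothetical tolerance; treats near-constant continuous rewards as degenerate

kept_prompt_uids = [
    uid
    for uid, std in prompt_uid2metric_std.items()
    # Keep groups whose rewards genuinely vary, or single-response groups;
    # groups with (near-)constant rewards are filtered out, as before.
    if std > EPS or len(prompt_uid2metric_vals[uid]) == 1
]
```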
The recipe has been moved to verl-project/verl-recipe as a submodule (#4795). Please submit a PR to verl-recipe.
fix bugs for #4786