[rl] Refactor Episode definition to be a single completion instead of a group#2529
Conversation
Group concept for GRPO
Thanks for the PR! I will do a proper review tomorrow. At first glance it looks good, but if we are doing this refactoring, I wonder if we should just go ahead and remove the concept of episode from generate completely, i.e., generate only provides generations -- the controller is responsible for the "episode" abstraction and for putting everything together.
Episode = Trajectory, where a Trajectory is a group of Transitions (turns)
Let me know if you agree that we should make this change now, or if I should go ahead and review the PR as is.
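A hypothetical sketch of that hierarchy (the `Transition` and `Trajectory` classes and their fields are illustrative, not from this PR):

```python
from dataclasses import dataclass, field


@dataclass
class Transition:
    """One turn: the prompt seen and the completion produced."""
    prompt: str
    completion: str
    reward: float = 0.0


@dataclass
class Trajectory:
    """An episode: an ordered group of transitions (turns)."""
    transitions: list[Transition] = field(default_factory=list)

    @property
    def total_reward(self) -> float:
        # Sum per-turn rewards; a real implementation might discount instead.
        return sum(t.reward for t in self.transitions)
```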
```python
# A Group holds all episodes that share the same prompt (the "G" in GRPO).
# Advantages are normalized within each group.
Group = list[Episode]
```
Group, as a term, is too general. Can we just use Episodes? (or alternatively EpisodeGroup)
I think EpisodeGroup is a good name
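Under that suggestion, the alias from the diff above would look like this (the `Episode` class here is a minimal placeholder for illustration; the real one has more fields):

```python
from dataclasses import dataclass


@dataclass
class Episode:
    """Minimal placeholder; the real Episode carries more fields."""
    prompt: str
    completion: str
    reward: float = 0.0


# EpisodeGroup: all episodes that share the same prompt (the "G" in GRPO);
# advantages are normalized within each group.
EpisodeGroup = list[Episode]
```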
Regarding the naming of Episode, can we do some research on popular RL frameworks to see what naming they use? I do feel like I see Trajectory/Generations/Rollouts more often than Episode.
```
@@ -17,7 +17,7 @@
The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training.
```
I personally don't feel we need this file name change... But if we do can we discuss it in a separate PR?
Right now simple_grpo is specifically tied to the sum_digits task and can't plug in other tasks -- including the evaluate() function and the code under the config.log_samples branch. I'd prefer to make it more explicitly task-specific until we have a good abstraction.
To be clear: my comment's focus was not really about naming, but more about the generator output, i.e. a generator produces a single generation, not a class that contains rewards, ground truth, etc. -- that is the rollout's job. We can talk about naming, but to clarify the concepts in Tinker's terms, they would be: This is what I had in the RFC I was working on for multiturn.
@felipemello1 @tianyu-l Thanks for the discussion! I want to make this
I removed
IIRC in GRPO, the reward is graded after you have all the Generations and will be relative. Two questions:
@tianyu-l @felipemello1 @acisseJZhong @wwwjn re:
I proposed
We are not grading the reward. Actually, the reward itself is the "score" we give to the completion generated by the Generator. In sum_digits, the reward is defined in the reward function (link); it is a tensor which contains a single score for each completion. In the grader, we simply plug the reward into each completion within the
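A minimal sketch of that idea, using plain Python floats instead of a tensor for brevity; the exact-match rule below is a stand-in for illustration, not the repo's actual sum_digits reward:

```python
def compute_rewards(completions: list[str], targets: list[str]) -> list[float]:
    # One scalar score per completion: 1.0 for an exact match, else 0.0.
    # Nothing here is relative to the group; that happens later, in the
    # controller's advantage computation.
    return [1.0 if c.strip() == t else 0.0 for c, t in zip(completions, targets)]
```

The grader then just attaches score i to completion i; no group-relative "grading" happens at this stage.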
This is the current design in this PR, but after discussion with @felipemello1, I would prefer to move to this design (comment), adding an extra layer of
If you guys compare the
I think I can propose an RFC later to have a more in-depth discussion. For now, I am not too worried about what we call it or if we have an
I just want the
That works for me @wwwjn !
Same principle goes here: the Trainer should train, not compute advantages. That is the controller's job.
Yeah, we need a refactor here. This is a leftover issue from my previous refactor. The reward -> advantage -> loss function flow should be moved out of the trainer, into a separate
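One hypothetical shape for pulling the reward -> advantage step out of the trainer (the function name `compute_group_advantages` and its signature are illustrative, not from this PR):

```python
from collections import defaultdict


def compute_group_advantages(
    rewards: list[float], group_ids: list[str]
) -> list[float]:
    """Controller-side GRPO step: advantage = reward - group mean reward."""
    # Bucket episode indices by their group (all samples of one prompt).
    groups: dict[str, list[int]] = defaultdict(list)
    for idx, gid in enumerate(group_ids):
        groups[gid].append(idx)

    advantages = [0.0] * len(rewards)
    for indices in groups.values():
        mean_reward = sum(rewards[i] for i in indices) / len(indices)
        for i in indices:
            advantages[i] = rewards[i] - mean_reward
    return advantages
```

With this split, the trainer only ever sees per-episode advantages, never raw rewards or group structure.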
```python
train_prompts = []
train_answers = []
train_questions = []
```
Any reason for the prefix "train_"? I wonder if we should remove it.
```python
# 3. Controller computes GRPO advantages (normalize within group)
groups: dict[str, list[int]] = defaultdict(list)
for idx, ep in enumerate(episodes):
    groups[ep.group_id].append(idx)
for indices in groups.values():
    rewards = torch.tensor([episodes[i].reward for i in indices])
    mean_reward = rewards.mean().item()
    for i in indices:
        episodes[i].advantage = episodes[i].reward - mean_reward
```
losing TrajectoryGroup feels ... a bit strange, but I'm ok with it
```python
async def generate(
    self,
    prompt_texts: list[str],
    expected_answers: list[str] | None = None,
```
This seems required right now?
```
Each prompt produces ``num_samples_per_prompt`` Episodes. Episodes
from the same prompt share a ``group_id`` so the controller can
compute group-level advantages later.
```
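A sketch of how that sharing could look on the generator side, as currently designed; the `Episode` fields and the `make_episodes` helper are illustrative, not the repo's actual code:

```python
import uuid
from dataclasses import dataclass


@dataclass
class Episode:
    prompt: str
    completion: str
    group_id: str


def make_episodes(prompt: str, completions: list[str]) -> list[Episode]:
    # All samples drawn from one prompt get the same group_id, so the
    # controller can later normalize advantages within the group.
    gid = uuid.uuid4().hex
    return [Episode(prompt, c, gid) for c in completions]
```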
These are GRPO (or any GxPO) details, which shouldn't be coupled with the generator?
Please leave a TODO if you don't want to tackle it in this PR.
- `Episode` to represent a group of completions in GRPO, which requires several places to flatten the list.
- `group_id` field to identify the group.