
[rl] Refactor Episode definition to be a single completion instead of a group #2529

Open
wwwjn wants to merge 2 commits into main from rl-refactor

Conversation

@wwwjn
Contributor

@wwwjn wwwjn commented Mar 9, 2026

  • Previously we used Episode to represent a group of completions in GRPO, which required flattening the list in several places.
  • Flatten the Episode concept and introduce a group_id field to identify the group.
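As a minimal sketch of the refactored shape (field names here are illustrative assumptions, not the PR's actual code): each Episode now holds a single completion, and a group_id ties together completions sampled from the same prompt.

```python
from dataclasses import dataclass


# Hypothetical sketch: one completion per Episode; completions sampled
# from the same prompt share a group_id (the "G" in GRPO).
@dataclass
class Episode:
    prompt: str
    completion: str
    reward: float
    group_id: str


episodes = [
    Episode("1+2+3?", "6", 1.0, "g0"),
    Episode("1+2+3?", "5", 0.0, "g0"),
]

# A "group" is no longer a nested structure -- it is just the Episodes
# whose group_id matches.
group = [e for e in episodes if e.group_id == "g0"]
```

This removes the need to flatten nested lists before training, at the cost of recovering groups by key when computing advantages.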

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 9, 2026
@wwwjn wwwjn requested review from acisseJZhong, daniellepintz, felipemello1, joecummings and tianyu-l and removed request for acisseJZhong and tianyu-l March 9, 2026 21:34
@wwwjn wwwjn changed the title Refactor Episode definition and introduce Group concept for GRPO [rl] Refactor Episode definition and introduce Group concept for GRPO Mar 9, 2026

@felipemello1 felipemello1 left a comment


Thanks for the PR! I will do a proper review tomorrow. At first glance it looks good, but if we are doing this refactoring, I wonder if we should just go ahead and remove the concept of episode from generate completely, i.e., generate only provides generations -- the controller is responsible for the "episode" abstraction and putting everything together.

Episode = Trajectory, where a Trajectory is a group of Transitions (turns)

https://github.com/thinking-machines-lab/tinker-cookbook/blob/934f0d9b2f53c3edff02cbf23ec6da8682047fa5/tinker_cookbook/rl/rollouts.py#L74-L112

Let me know if you agree that we should do this change now, or if I should go ahead and review the PR as is.


# A Group holds all episodes that share the same prompt (the "G" in GRPO).
# Advantages are normalized within each group.
Group = list[Episode]
Contributor


Group, as a term, is too general. Can we just use Episodes? (or alternatively EpisodeGroup)

Contributor Author


I think EpisodeGroup is a good name

@tianyu-l
Contributor

Episode = Trajectory, where a Trajectory is a group of Transitions (turns)

@felipemello1

  1. The name Episode came from @allenwang28's forge-related experience. I'm unopinionated here, but would appreciate something that makes the most sense and is common and consistent.
  2. In the link you shared, they have something called TrajectoryGroup, which may be another concept?

@acisseJZhong
Contributor

acisseJZhong commented Mar 10, 2026

Episode = Trajectory, where a Trajectory is a group of Transitions (turns)

Regarding naming of Episode, can we do some research on popular RL frameworks to see what corresponding naming they use? I do feel like I see Trajectory/Generations/Rollouts more often than Episode.

@@ -17,7 +17,7 @@
The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training.
Contributor

@daniellepintz daniellepintz Mar 10, 2026


I personally don't feel we need this file name change... But if we do, can we discuss it in a separate PR?

Contributor Author


Right now simple_grpo is specifically tied to the sum_digits task and can't plug in other tasks -- including the evaluate() function and the code under the config.log_samples branch. I'd prefer to make the name more specific before we have a good abstraction.

@felipemello1

felipemello1 commented Mar 10, 2026

Episode = Trajectory, where a Trajectory is a group of Transitions (turns)

@felipemello1

  1. The name of Episode was coming from @allenwang28 on forge-related experience. I'm unopinionated here, but would appreciate something that makes most sense, is common and consistent.
  2. In the link you showed, they have something called TrajectoryGroup, which may be another concept?

To be clear: my comment's focus was not really about naming, but more about the generator output, i.e. a generator produces a single generation, not a class that contains rewards, ground truth, etc. -- that is the rollout's job.

We can talk about naming, but to clarify the concepts in Tinker's terms, they would be:

  • Transition: a single turn
  • Trajectory (List[Transition] + final reward): a collection of turns, a completed rollout
  • TrajectoryGroup (List[Trajectory]): a collection of rollouts for GRPO, e.g. if we do n=8 rollouts for the same prompt
  • Datum (or something like that; what I would call Episode): the processed Trajectory -- a final piece of data ready for training. It has ref_logprobs, advantages, mask, etc.

This is what i had in the RFC i was working on for multiturn

from dataclasses import dataclass, field
from enum import Enum
from typing import Any

import torch


class TrajectoryStatus(Enum):
    # Minimal sketch of the status enum; the full definition is not shown
    # in this RFC excerpt
    COMPLETED = "completed"
    ABORTED = "aborted"


@dataclass
class GeneratorOutput:
    """Clean output from generator."""

    tokens: list[int]
    logprobs: list[float]
    stop_reason: str | None = None  # "stop", "length", or "abort" from vLLM
    message: dict | None = None

@dataclass
class TokensWithLogprobs:
    """Exact tokens from generation with their logprobs."""

    tokens: list[int]
    logprobs: list[float]

    def __post_init__(self):
        if self.logprobs and len(self.logprobs) != len(self.tokens):
            raise ValueError("logprobs length must equal tokens length")

@dataclass
class Transition:
    """One generation step in a trajectory."""

    observation_tokens: list[int]
    action: TokensWithLogprobs
    messages: list[dict] | None = None
    rewards: dict[str, float] = field(default_factory=dict)
    info: dict[str, Any] = field(default_factory=dict)


@dataclass
class Trajectory:
    """Pure data container for a rollout"""

    transitions: list[Transition] = field(default_factory=list)
    status: TrajectoryStatus = TrajectoryStatus.COMPLETED
    rewards: dict[str, float] = field(default_factory=dict)
    advantage: float | None = None  # set by grpo_advantage / gdpo_advantage
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def prompt_tokens(self) -> list[int]:
        """Initial prompt is the first transition's observation."""
        return self.transitions[0].observation_tokens if self.transitions else []

    @property
    def total_tokens(self) -> int:
        """Total tokens in the trajectory (last observation + last action)."""
        if not self.transitions:
            return 0
        last = self.transitions[-1]
        return len(last.observation_tokens) + len(last.action.tokens)


@dataclass
class Episode:
    """One training sequence

    All tensors are already shifted for next-token prediction:
    - input_ids[i] is the model input at position i
    - target_ids[i] is the prediction target at position i
    - loss_mask[i], policy_logprobs[i], advantage[i] correspond to predicting target_ids[i]
    """

    input_ids: torch.Tensor
    target_ids: torch.Tensor
    loss_mask: torch.Tensor
    policy_logprobs: torch.Tensor
    advantage: torch.Tensor  # per-token [seq_len]
    policy_version: int

    # optional
    ref_logprobs: torch.Tensor | None = None
    messages: list[dict] | None = None
    rewards: dict[str, float] = field(default_factory=dict)
    group_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        n = len(self.input_ids)
        if len(self.target_ids) != n:
            raise ValueError(
                f"target_ids length ({len(self.target_ids)}) must equal "
                f"input_ids length ({n})"
            )
        if len(self.loss_mask) != n:
            raise ValueError(
                f"loss_mask length ({len(self.loss_mask)}) must equal "
                f"input_ids length ({n})"
            )
        if self.policy_logprobs.numel() > 0 and len(self.policy_logprobs) != n:
            raise ValueError(
                f"policy_logprobs length ({len(self.policy_logprobs)}) must equal "
                f"input_ids length ({n})"
            )
        if self.ref_logprobs is not None and len(self.ref_logprobs) != n:
            raise ValueError(
                f"ref_logprobs length ({len(self.ref_logprobs)}) must equal "
                f"input_ids length ({n})"
            )
        if len(self.advantage) != n:
            raise ValueError(
                f"advantage length ({len(self.advantage)}) must equal "
                f"input_ids length ({n})"
            )

@wwwjn
Contributor Author

wwwjn commented Mar 10, 2026

@felipemello1 @tianyu-l Thanks for the discussion! I want to make this Datum / Episode as simple as possible, while making it easy for everyone to read the code and keeping all the concepts super clear. I mainly focus on the single-turn design; we can refactor later for multi-turn.

  • Generation / Sample (name to be decided, but Transition seems not descriptive enough): output of the generator, a single-turn generation. In GRPO, the generator will produce num_groups Generations all at once, so it will return List[Generation].
    • Generation = (generated token ids, log probs)
  • Episode (Prompt + Generation + reward): a single-turn result. This class will be assembled by the Grader.
  • EpisodeGroup (List[Episode]): a collection of rollouts for GRPO. This class will be assembled by the Grader.

I removed the Datum layer because we don't have any processor other than the Grader between the generator and the trainer. As we evolve, there might be a post-processor or batcher; then we can introduce another layer.

@tianyu-l
Contributor

@wwwjn

Episode (Generation + reward): a single turn result. This class will be assembled by Grader.

IIRC in GRPO, the reward is graded after you have all the Generations and will be relative. Two questions:

  1. In today's sum_digits task, I didn't see a relative score. Does it mean we are not using GRPO as is? cc @daniellepintz
  2. Do you use Episode as a placeholder, which the Generator first fills with the Generation, and then the Grader fills with the reward? This is not very object-oriented but may be OK for now. Just asking.

@allenwang28
Contributor

allenwang28 commented Mar 10, 2026

@tianyu-l @felipemello1 @acisseJZhong @wwwjn

re: Episode vs Trajectory, I am not opinionated on this either, but some notes on the etymology of these terms:

  • Episode is the older, more general RL term referring to the complete interaction from initial state to terminal state
  • Trajectory is the specific sequence of (state, action, reward) tuples within an episode. Technically every episode produces a trajectory, but not every trajectory is a full episode (a trajectory can cover only part of an episode)
  • For LLMs this distinction collapses, because for reasoning loops like ours a single generation produces a trajectory which == episode.
  • DeepSeek R1 paper uses "Trajectory", and Episode may show up more when people think about multi-turn or agentic settings when you have a genuine env loop

I proposed Episode as the most technically correct representation of what we expect goes into the trainer, but given that there are multiple instances where naming blurs the lines (i.e. I could see Sampler vs Generator vs Policy becoming a debate as well), I would recommend we pick and commit to a heuristic for defining these terms sooner rather than later.

@wwwjn
Contributor Author

wwwjn commented Mar 10, 2026

@wwwjn

Episode (Generation + reward): a single turn result. This class will be assembled by Grader.

IIRC in GRPO, the reward is graded after you have all the Generations and will be relative. Two questions:

  1. In today's sum_digits task, I didn't see a relative score. Does it mean we are not using GRPO as is? cc @daniellepintz

We are not grading the reward. Actually, the reward itself is the "score" we give to the completion generated by the Generator. In sum_digits, the reward is defined in the reward function (link): a tensor which contains a single score for each completion. In the grader, we simply plug the reward into each completion within the Episode. The relative score you are referring to is done in Trainer.step() (link):

advantages = rewards - rewards.mean()

to calculate the advantage -> loss for the backward pass. The trainer part (which handles the GRPO logic) might need a refactor and better separation, or a rename to GRPOTrainer.

  2. Do you use Episode as a placeholder, which the Generator first fills with the Generation, and then the Grader fills with the reward? This is not very object-oriented but may be OK for now. Just asking.

This is the current design in this PR, but after discussion with @felipemello1, I would prefer to move to this design (comment), adding an extra layer of Generation to represent the output of the Generator.

@felipemello1

felipemello1 commented Mar 10, 2026

If you compare the Trajectory vs Episode definitions I shared, the main differences are:

  1. Trajectory is discrete steps (Trajectory.transitions); Episode is in token space (Episode.input_ids), ready for fwd_bwd. This matters for multiturn, but we are not there yet.
  2. Episode already has loss_mask (i.e. the trainer doesn't know what to mask) -- I think we naively mask len(prompt). This works for single turn but fails for multiturn.
  3. Episode has ref_logprobs

I think I can propose an RFC later for a more in-depth discussion. For now, I am not too worried about what we call it, or about having an Episode with many empty fields that we fill along the way. @allenwang28's suggestion works for me.

I just want the Generator to not be opinionated about what an Episode is. That way, people can reuse it and change it to whatever they want.

I would prefer move to this design (#2529 (comment)), adding an extra layer of Generation to represent the output of Generator.

That works for me @wwwjn !

The relative score you are referring to is in done in Trainer.step()

Same principle applies here: the Trainer should train, not compute advantages. That is the controller's job.
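To make the loss_mask point above concrete, here is a sketch of the naive single-turn masking being described (an illustration of the idea, not code from this PR):

```python
def single_turn_loss_mask(prompt_len: int, seq_len: int) -> list[int]:
    """Naive mask: 0 over the prompt, 1 over the completion.

    This works for single turn, where the only non-trained tokens are the
    prompt prefix. It fails for multiturn, where later user/environment
    turns are interleaved with assistant turns and should also be masked.
    """
    return [0] * prompt_len + [1] * (seq_len - prompt_len)


# Prompt of 3 tokens followed by a 4-token completion.
mask = single_turn_loss_mask(prompt_len=3, seq_len=7)
```

Storing the mask on the Episode, as in the RFC sketch above, means whoever assembles the Episode decides what is trainable, and the trainer stays agnostic.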

@wwwjn
Contributor Author

wwwjn commented Mar 10, 2026

The relative score you are referring to is in done in Trainer.step()

Same principle goes here: Trainer should train, not compute advantage. This is the controllers job.

Yeah, we need a refactor here. This is a leftover issue from my previous refactor. The reward -> advantage -> loss computation should be moved out of the trainer into a separate loss.py. Will work on this later.
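A sketch of what that separation could look like (hypothetical loss.py helpers in plain Python; the names and signatures are assumptions, not this repo's API):

```python
def grpo_advantages(rewards: list[float]) -> list[float]:
    """Group-relative advantage: each reward minus the group mean
    (the computation currently inlined in Trainer.step())."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


def masked_pg_loss(logprobs: list[float], advantages: list[float],
                   loss_mask: list[int]) -> float:
    """Masked policy-gradient loss over per-token values. The trainer
    would only call this; it no longer computes advantages itself."""
    total = sum(-lp * a * m
                for lp, a, m in zip(logprobs, advantages, loss_mask))
    return total / max(sum(loss_mask), 1)


adv = grpo_advantages([1.0, 0.0, 1.0, 0.0])  # [0.5, -0.5, 0.5, -0.5]
```

With this split, the controller computes advantages per group, and the trainer's step reduces to forward, loss, backward.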

@wwwjn wwwjn changed the title [rl] Refactor Episode definition and introduce Group concept for GRPO [rl] Refactor Episode definition to be a single completion instead of a group Mar 11, 2026
Comment on lines 301 to 303
train_prompts = []
train_answers = []
train_questions = []
Contributor


Any reason for the prefix "train_"? I wonder if we should remove it.

Comment on lines +324 to +332
# 3. Controller computes GRPO advantages (normalize within group)
groups: dict[str, list[int]] = defaultdict(list)
for idx, ep in enumerate(episodes):
    groups[ep.group_id].append(idx)
for indices in groups.values():
    rewards = torch.tensor([episodes[i].reward for i in indices])
    mean_reward = rewards.mean().item()
    for i in indices:
        episodes[i].advantage = episodes[i].reward - mean_reward
Contributor


losing TrajectoryGroup feels ... a bit strange, but I'm ok with it

async def generate(
    self,
    prompt_texts: list[str],
    expected_answers: list[str] | None = None,
Contributor


This seems required right now?

Comment on lines +234 to +236
Each prompt produces ``num_samples_per_prompt`` Episodes. Episodes
from the same prompt share a ``group_id`` so the controller can
compute group-level advantages later.
Contributor


These are GRPO (or any GxPO) details, which shouldn't be coupled with the generator.
Please leave a TODO if you don't want to tackle it in this PR.
