Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Document the GRPO loss operator
Add docs/operators/grpo-loss.md following the operator doc template: purpose,
entry point, group specification, backend table, tensor contract, dispatch
behavior, reference semantics, benchmark numbers, tests, and limitations.
Link it from the operators index.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
  • Loading branch information
KJLdefeated and claude committed Jun 9, 2026
commit f434e31f15f0b2d1ca06796f5ce8c7189ff82408
1 change: 1 addition & 0 deletions docs/operators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ Every operator page should include:
## Current Pages

- [Fused LogP](fused-logp.md)
- [GRPO Loss](grpo-loss.md)
- [Sampling](sampling.md)
- [Operator Doc Template](../contributing/operator-doc-template.md)
167 changes: 167 additions & 0 deletions docs/operators/grpo-loss.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# GRPO Loss

GRPO Loss computes the Group Relative Policy Optimization objective for RL post-training:
it normalizes raw sequence rewards within each generation group into advantages, then
evaluates the clipped surrogate objective plus a reference-KL penalty over the active
completion tokens. It targets the GRPO training step, where a naive PyTorch implementation
allocates several broadcasted `[batch, completion_len]` intermediates and a per-token
advantage tensor.

The operator consumes per-token **log-probs**, so it composes directly with the
[Fused LogP](fused-logp.md) operator:

```
logits --[logp op]--> current_logps --[grpo_loss op]--> loss
```
Comment on lines +12 to +14

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language identifier to the fenced code block.

This fence is missing a language tag (markdownlint MD040). Use something like ```text for the pipeline diagram.

Suggested doc fix
-```
+```text
 logits --[logp op]--> logps --[grpo_loss op]--> loss
</details>

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 markdownlint-cli2 (0.22.1)</summary>

[warning] 12-12: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/grpo-loss.md around lines 12 - 14, The fenced code block
containing the pipeline diagram "logits --[logp op]--> logps --[grpo_loss op]-->
loss" is missing a language identifier; update the fence to include a language
tag (e.g., use ```text) so the block follows markdownlint MD040; locate the
fence around that pipeline diagram in docs/operators/grpo-loss.md and add the
language identifier to the opening backticks.


</details>

<!-- fingerprinting:phantom:triton:hawk -->

<!-- cr-comment:v1:13c3393cb81ae72a6c408c62 -->

_Source: Linters/SAST tools_

<!-- This is an auto-generated comment by CodeRabbit -->


It does not re-implement `log_softmax`; produce `current_logps` with the `logp` operator
(or any equivalent) first.

## Entry Point

```python
from rl_engine.kernels.registry import kernel_registry

grpo = kernel_registry.get_op("grpo_loss")

loss, policy_loss, kl = grpo.forward(
current_logps, # [B, T] current-policy per-token log-probs (differentiable)
old_logps, # [B, T] behavior-policy log-probs (constant)
ref_logps, # [B, T] reference-model log-probs (constant)
rewards, # [B] one scalar reward per sequence
completion_mask, # [B, T] bool; True = active completion token
clip_eps=0.2,
beta=0.04,
samples_per_prompt=8, # uniform groups; or pass group_boundaries=[...]
)

loss.backward() # gradient flows into current_logps
```

`B = num_prompts * samples_per_prompt`. `forward` returns a 3-tuple
`(loss, policy_loss, kl)`; only `loss` is differentiable (`policy_loss` and `kl` are
detached reporting scalars).

### Group specification

Provide **exactly one** of:

- `samples_per_prompt: int` — uniform groups (every prompt has the same number of samples).
- `group_boundaries` — CSR-style offsets of length `num_groups + 1` (e.g. `[0, 8, 16, 24]`)
for variable-sized groups.

### Lower-level helpers

```python
adv = grpo.group_advantages(rewards, samples_per_prompt=8) # [B] per-sequence advantages
loss, policy_loss, kl = grpo.apply( # skip reward normalization
current_logps, old_logps, ref_logps, advantages, completion_mask,
clip_eps=0.2, beta=0.04,
)
```

!!! note "`apply` advantage layout differs by backend"
`NativeGRPOLossOp.apply` takes **per-token** advantages `[B, T]` (use
`expand_advantages`), while `TritonGRPOLossOp.apply` takes **per-sequence**
advantages `[B]` and gathers per token inside the kernel. Prefer `forward`, whose
signature is identical across backends.

## Backends

| Backend | Wrapper | Native symbol | Status |
| --- | --- | --- | --- |
| CUDA | `TritonGRPOLossOp` | Triton JIT kernels | Fused forward + analytic backward. |
| ROCm | `TritonGRPOLossOp` | Triton JIT kernels | Same kernels; requires Triton. |
| PyTorch fallback | `NativeGRPOLossOp` | None | Reference path; CPU and Triton-less GPUs. |

The Triton op fuses three kernels: `_group_norm_kernel` (per-group reward mean/std in
registers), and token-parallel `_grpo_fwd_kernel` / `_grpo_bwd_kernel`. Each token gathers
its advantage on the fly (`seq_id = token_index // completion_len`), so the broadcasted
`[B, T]` advantage tensor is never materialized.

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `current_logps` | `[B, T]` | float (fp32 recommended) | Differentiable input; on the target device. |
| `old_logps` | `[B, T]` | float | Constant (no grad). |
| `ref_logps` | `[B, T]` | float | Constant (no grad). |
| `rewards` | `[B]` | float | One scalar per sequence. |
| `completion_mask` | `[B, T]` | bool / {0,1} | 2-D; `True` marks active tokens. |
| `loss` (output) | scalar | float32 | `policy_loss + beta * kl`. |
| `policy_loss`, `kl` (output) | scalar | float32 | Detached reporting values. |

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix incorrect gradient contract for policy_loss and kl outputs.

The docs say these are “detached reporting values,” but the implementation returns them directly (no .detach()), so they remain connected to autograd. Please update the wording (or detach in code if that behavior is intentional).


## Dispatch Behavior

`kernel_registry.get_op("grpo_loss")` selects `TritonGRPOLossOp` first on CUDA and ROCm,
and `NativeGRPOLossOp` on CPU or when Triton / the CUDA backend is unavailable. The Triton
op raises if handed CPU tensors, so on CPU always use the native op (the registry does this
automatically).

## Accuracy

Reference semantics (matching `examples/grpo_single_gpu.py`):

```python
# advantages: group-normalized rewards (population std, unbiased=False)
grouped = rewards.view(-1, samples_per_prompt)
adv = (grouped - grouped.mean(1, keepdim=True)) / grouped.std(1, keepdim=True, unbiased=False).clamp_min(1e-6)
adv = adv.reshape(-1)[:, None].expand_as(completion_mask)

ratio = torch.exp(current_logps - old_logps)
policy = -torch.minimum(ratio * adv, torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)
diff = ref_logps - current_logps
kl = torch.exp(diff) - diff - 1.0 # k3 estimator
loss = masked_mean(policy, completion_mask) + beta * masked_mean(kl, completion_mask)
```

The Triton op matches the native reference (forward and backward) to `atol=1e-4` in fp32.
Composing with the dispatched CUDA fused logp matches a torch oracle to `atol=1e-3`.

## Performance Notes

```bash
python benchmarks/benchmark_grpo_loss.py
python benchmarks/benchmark_grpo_loss.py --configs "64,8,512;256,16,1024"
```

Indicative results (RTX PRO 6000, SM120, fp32):

| shape (prompts × samples × len) | tokens | forward speedup | fwd+bwd speedup | peak VRAM (native → Triton) |
| --- | --- | --- | --- | --- |
| 64 × 8 × 512 | 0.26M | 4.7× | 3.0× | 10 MB → 1 MB |
| 128 × 8 × 1024 | 1.05M | 3.2× | 2.7× | 40 MB → 4 MB |
| 256 × 16 × 1024 | 4.19M | 3.0× | 3.0× | 160 MB → 16 MB |

The ~10× VRAM reduction comes from not materializing the broadcasted advantage and
per-token surrogate/KL intermediates.

## Tests

```bash
python -m pytest tests/test_grpo_loss.py -v
```

Covers the native reference, Triton forward/backward vs native, the
`logp → grpo_loss` pipeline, masked-token invariance, an SGD loss step, and registry
dispatch. Triton tests skip without CUDA + Triton.

## Known Limitations

- `TritonGRPOLossOp` requires CUDA tensors and a working Triton install; CPU falls back to
the native op.
- `old_logps` / `ref_logps` are precomputed per-token constants — only the current policy
is differentiated (standard for GRPO).
- Inputs are dense `[B, T]` (rectangular). Variable-length completions should be padded and
masked via `completion_mask`.
- Group reward normalization is a separate lightweight kernel (one scalar per sequence);
it is intentionally not fused into the token kernel, which would trade away
token-level parallelism.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/loss/grpo_loss.py`
- `rl_engine/kernels/ops/triton/triton_grpo_loss.py`
- `rl_engine/kernels/registry.py`
- `tests/test_grpo_loss.py`
- `benchmarks/benchmark_grpo_loss.py`