Merged
(doc) add grpo-loss kernel doc
KJLdefeated committed Jun 9, 2026
commit e851c484a5cf1fbc1f399569a46a1fe9563fa708
66 changes: 11 additions & 55 deletions docs/operators/grpo-loss.md
Expand Up @@ -7,29 +7,24 @@ completion tokens. It targets the GRPO training step, where a naive PyTorch impl
allocates several broadcasted `[batch, completion_len]` intermediates and a per-token
advantage tensor.

The operator consumes per-token **log-probs**, so it composes directly with the
[Fused LogP](fused-logp.md) operator:
The operator consumes per-token **log-probs**, so it composes directly with the [Fused LogP](fused-logp.md) operator:

```
logits --[logp op]--> current_logps --[grpo_loss op]--> loss
logits --[logp op]--> logps --[grpo_loss op]--> loss
```
Comment on lines +12 to +14

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language identifier to the fenced code block.

This fence is missing a language tag (markdownlint MD040). Use something like ```text for the pipeline diagram.

Suggested doc fix
-```
+```text
 logits --[logp op]--> logps --[grpo_loss op]--> loss
</details>

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 markdownlint-cli2 (0.22.1)</summary>

[warning] 12-12: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/grpo-loss.md around lines 12 - 14, The fenced code block
containing the pipeline diagram "logits --[logp op]--> logps --[grpo_loss op]-->
loss" is missing a language identifier; update the fence to include a language
tag (e.g., use ```text) so the block follows markdownlint MD040; locate the
fence around that pipeline diagram in docs/operators/grpo-loss.md and add the
language identifier to the opening backticks.


</details>

_Source: Linters/SAST tools_

<!-- This is an auto-generated comment by CodeRabbit -->


It does not re-implement `log_softmax`; produce `current_logps` with the `logp` operator
(or any equivalent) first.

## Entry Point

```python
from rl_engine.kernels.registry import kernel_registry

grpo = kernel_registry.get_op("grpo_loss")
grpo_loss = kernel_registry.get_op("grpo_loss")

loss, policy_loss, kl = grpo.forward(
current_logps, # [B, T] current-policy per-token log-probs (differentiable)
old_logps, # [B, T] behavior-policy log-probs (constant)
ref_logps, # [B, T] reference-model log-probs (constant)
rewards, # [B] one scalar reward per sequence
completion_mask, # [B, T] bool; True = active completion token
loss, policy_loss, kl = grpo_loss(
current_logps, # [B, T] current policy logps (differentiable)
old_logps, # [B, T] inference engine log-probs
ref_logps, # [B, T] reference model log-probs
rewards, # [B]
completion_mask, # [B, T]
clip_eps=0.2,
beta=0.04,
samples_per_prompt=8, # uniform groups; or pass group_boundaries=[...]
Expand All @@ -38,40 +33,20 @@ loss, policy_loss, kl = grpo.forward(
loss.backward() # gradient flows into current_logps
```

`B = num_prompts * samples_per_prompt`. `forward` returns a 3-tuple
`(loss, policy_loss, kl)`; only `loss` is differentiable (`policy_loss` and `kl` are
detached reporting scalars).
Note: `B = num_prompts * samples_per_prompt`. The [B, T] tensors are made contiguous and flattened to 1-D [N = B*T] before the kernel launch.
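
The flattening in this note, together with the `seq_id = token_index // completion_len` lookup used by the token kernel, comes down to row-major index arithmetic. A minimal plain-Python sketch, assuming a contiguous row-major layout; `flat_index` and `seq_and_pos` are hypothetical helper names used only for illustration and are not part of `rl_engine`:

```python
from typing import Tuple


def flat_index(b: int, t: int, completion_len: int) -> int:
    """Map a [B, T] coordinate to its position in the flattened [N = B*T] view."""
    return b * completion_len + t


def seq_and_pos(token_index: int, completion_len: int) -> Tuple[int, int]:
    """Recover (sequence id, position in sequence) from a flat token index."""
    return divmod(token_index, completion_len)


# Gather a per-sequence value for every flat token, without ever
# materializing a [B, T] broadcast of the per-sequence values.
B, T = 3, 4
per_sequence = [0.5, -1.0, 0.25]  # one value per sequence, shape [B]
per_token = [per_sequence[seq_and_pos(n, T)[0]] for n in range(B * T)]
```

Each flat token reads its sequence's value through integer division, which is why no per-token copy of the per-sequence tensor needs to be stored.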

### Group specification

Provide **exactly one** of:

- `samples_per_prompt: int` — uniform groups (every prompt has the same number of samples).
- `group_boundaries` — CSR-style offsets of length `num_groups + 1` (e.g. `[0, 8, 16, 24]`)
for variable-sized groups.
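
As an illustration of the two group specifications, the sketch below converts a uniform `samples_per_prompt` into the same CSR-style offsets and then normalizes rewards within each group. It is plain Python, not the operator's code: `resolve_boundaries` and `normalize_by_group` are hypothetical names, and the population standard deviation and `eps=1e-6` are assumptions, since this page does not state which convention the kernels use.

```python
from statistics import fmean, pstdev


def resolve_boundaries(batch_size, samples_per_prompt=None, group_boundaries=None):
    """Return CSR-style offsets of length num_groups + 1 (hypothetical helper)."""
    # Exactly one of the two specifications must be given.
    if (samples_per_prompt is None) == (group_boundaries is None):
        raise ValueError("provide exactly one of samples_per_prompt or group_boundaries")
    if group_boundaries is not None:
        bounds = list(group_boundaries)
    else:
        if batch_size % samples_per_prompt != 0:
            raise ValueError("batch_size must be a multiple of samples_per_prompt")
        bounds = list(range(0, batch_size + 1, samples_per_prompt))
    if len(bounds) < 2 or bounds[0] != 0 or bounds[-1] != batch_size:
        raise ValueError("boundaries must start at 0 and end at batch_size")
    return bounds


def normalize_by_group(rewards, bounds, eps=1e-6):
    """Per-sequence advantages: (reward - group mean) / (group std + eps)."""
    advantages = []
    for start, end in zip(bounds, bounds[1:]):
        group = rewards[start:end]
        mean = fmean(group)
        std = pstdev(group)  # assumption: population std; the real op may differ
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


uniform = resolve_boundaries(24, samples_per_prompt=8)
ragged = normalize_by_group([1.0, 3.0, 2.0, 2.0, 2.0], [0, 2, 5])
```

Here `uniform` has the same shape as the `[0, 8, 16, 24]` example, and the second call shows a variable-sized layout with groups of 2 and 3 sequences.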

### Lower-level helpers

```python
adv = grpo.group_advantages(rewards, samples_per_prompt=8) # [B] per-sequence advantages
loss, policy_loss, kl = grpo.apply( # skip reward normalization
current_logps, old_logps, ref_logps, advantages, completion_mask,
clip_eps=0.2, beta=0.04,
)
```

!!! note "`apply` advantage layout differs by backend"
`NativeGRPOLossOp.apply` takes **per-token** advantages `[B, T]` (use
`expand_advantages`), while `TritonGRPOLossOp.apply` takes **per-sequence**
advantages `[B]` and gathers per token inside the kernel. Prefer `forward`, whose
signature is identical across backends.

## Backends

| Backend | Wrapper | Native symbol | Status |
| --- | --- | --- | --- |
| CUDA | `TritonGRPOLossOp` | Triton JIT kernels | Fused forward + analytic backward. |
| ROCm | `TritonGRPOLossOp` | Triton JIT kernels | Same kernels; requires Triton. |
| PyTorch fallback | `NativeGRPOLossOp` | None | Reference path; CPU and Triton-less GPUs. |

The Triton op fuses three kernels: `_group_norm_kernel` (per-group reward mean/std in
Expand All @@ -83,21 +58,14 @@ its advantage on the fly (`seq_id = token_index // completion_len`), so the broa

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `current_logps` | `[B, T]` | float (fp32 recommended) | Differentiable input; on the target device. |
| `current_logps` | `[B, T]` | float (fp32 recommended) | Differentiable input |
| `old_logps` | `[B, T]` | float | Constant (no grad). |
| `ref_logps` | `[B, T]` | float | Constant (no grad). |
| `rewards` | `[B]` | float | One scalar per sequence. |
| `completion_mask` | `[B, T]` | bool / {0,1} | 2-D; `True` marks active tokens. |
| `loss` (output) | scalar | float32 | `policy_loss + beta * kl`. |
| `policy_loss`, `kl` (output) | scalar | float32 | Detached reporting values. |
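
For orientation, here is a hedged plain-Python sketch of how the three outputs could relate. Only `loss = policy_loss + beta * kl` comes from the table above; the PPO-style clipped ratio, the `exp(d) - d - 1` KL estimator, and averaging over all active tokens are assumptions based on the commonly published GRPO formulation and may differ from this operator's reference semantics. `grpo_loss_sketch` is a hypothetical name.

```python
import math


def grpo_loss_sketch(cur, old, ref, adv, mask, clip_eps=0.2, beta=0.04):
    """cur/old/ref/mask are [B][T] nested lists; adv is a length-B list."""
    policy_sum = 0.0
    kl_sum = 0.0
    n_active = 0
    for b, row in enumerate(cur):
        for t, logp in enumerate(row):
            if not mask[b][t]:
                continue  # masked tokens contribute nothing
            ratio = math.exp(logp - old[b][t])
            clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
            # Assumed clipped surrogate, negated because we minimize.
            policy_sum -= min(ratio * adv[b], clipped * adv[b])
            # Assumed non-negative KL estimator against the reference model.
            d = ref[b][t] - logp
            kl_sum += math.exp(d) - d - 1.0
            n_active += 1
    denom = max(n_active, 1)  # assumption: mean over all active tokens
    policy_loss = policy_sum / denom
    kl = kl_sum / denom
    return policy_loss + beta * kl, policy_loss, kl


logps = [[-1.0, -2.0], [-0.5, -0.7]]
loss, policy_loss, kl = grpo_loss_sketch(
    logps, logps, logps, adv=[1.0, -1.0], mask=[[1, 1], [1, 0]]
)
```

When the current, behavior, and reference log-probs coincide, the ratio is 1 and the KL term vanishes, so `loss` reduces to the negated mean advantage over active tokens.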

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix incorrect gradient contract for policy_loss and kl outputs.

The docs say these are “detached reporting values,” but the implementation returns them directly (no .detach()), so they remain connected to autograd. Please update the wording (or detach in code if that behavior is intentional).


## Dispatch Behavior

`kernel_registry.get_op("grpo_loss")` selects `TritonGRPOLossOp` first on CUDA and ROCm,
and `NativeGRPOLossOp` on CPU or when Triton / the CUDA backend is unavailable. The Triton
op raises if handed CPU tensors, so on CPU always use the native op (the registry does this
automatically).
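
The selection rule described above can be sketched as a priority list with a fallback. This is an illustration only: `select_grpo_backend` and its `device` / `has_triton` arguments are hypothetical and do not reflect the real `kernel_registry` API, which is not shown on this page.

```python
def select_grpo_backend(device: str, has_triton: bool) -> str:
    """Return the wrapper class name a registry like the one described might pick."""
    # Candidates in priority order: (wrapper name, availability predicate).
    candidates = [
        ("TritonGRPOLossOp", lambda: device in {"cuda", "rocm"} and has_triton),
        ("NativeGRPOLossOp", lambda: True),  # reference path, always available
    ]
    for name, is_available in candidates:
        if is_available():
            return name
    raise RuntimeError("no backend available")  # unreachable with the list above


on_gpu = select_grpo_backend("cuda", has_triton=True)
on_cpu = select_grpo_backend("cpu", has_triton=True)
no_triton = select_grpo_backend("rocm", has_triton=False)
```

Keeping the native op last in the list means that CPU tensors and Triton-less GPUs never reach the Triton wrapper.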

## Accuracy

Reference semantics (matching `examples/grpo_single_gpu.py`):
Expand Down Expand Up @@ -146,18 +114,6 @@ Covers the native reference, Triton forward/backward vs native, the
`logp → grpo_loss` pipeline, masked-token invariance, an SGD loss step, and registry
dispatch. Triton tests skip without CUDA + Triton.

## Known Limitations

- `TritonGRPOLossOp` requires CUDA tensors and a working Triton install; CPU falls back to
the native op.
- `old_logps` / `ref_logps` are precomputed per-token constants — only the current policy
is differentiated (standard for GRPO).
- Inputs are dense `[B, T]` (rectangular). Variable-length completions should be padded and
masked via `completion_mask`.
- Group reward normalization is a separate lightweight kernel (one scalar per sequence);
it is intentionally not fused into the token kernel, which would trade away
token-level parallelism.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/loss/grpo_loss.py`