-
Notifications
You must be signed in to change notification settings - Fork 20
[FEAT][kernels]: implement fused GRPO loss with in-place group reward normalization #93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
13195f5
25e19bf
f434e31
e851c48
a4ca453
f1bc93a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,29 +7,24 @@ completion tokens. It targets the GRPO training step, where a naive PyTorch impl | |
| allocates several broadcasted `[batch, completion_len]` intermediates and a per-token | ||
| advantage tensor. | ||
|
|
||
| The operator consumes per-token **log-probs**, so it composes directly with the | ||
| [Fused LogP](fused-logp.md) operator: | ||
| The operator consumes per-token **log-probs**, so it composes directly with the [Fused LogP](fused-logp.md) operator: | ||
|
|
||
| ``` | ||
| logits --[logp op]--> current_logps --[grpo_loss op]--> loss | ||
| logits --[logp op]--> logps --[grpo_loss op]--> loss | ||
| ``` | ||
|
|
||
| It does not re-implement `log_softmax`; produce `current_logps` with the `logp` operator | ||
| (or any equivalent) first. | ||
|
|
||
| ## Entry Point | ||
|
|
||
| ```python | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| grpo = kernel_registry.get_op("grpo_loss") | ||
| grpo_loss = kernel_registry.get_op("grpo_loss") | ||
|
|
||
| loss, policy_loss, kl = grpo.forward( | ||
| current_logps, # [B, T] current-policy per-token log-probs (differentiable) | ||
| old_logps, # [B, T] behavior-policy log-probs (constant) | ||
| ref_logps, # [B, T] reference-model log-probs (constant) | ||
| rewards, # [B] one scalar reward per sequence | ||
| completion_mask, # [B, T] bool; True = active completion token | ||
| loss, policy_loss, kl = grpo_loss( | ||
| current_logps, # [B, T] current policy logps (differentiable) | ||
| old_logps, # [B, T] inference engine log-probs | ||
| ref_logps, # [B, T] reference model log-probs | ||
| rewards, # [B] | ||
| completion_mask, # [B, T] | ||
| clip_eps=0.2, | ||
| beta=0.04, | ||
| samples_per_prompt=8, # uniform groups; or pass group_boundaries=[...] | ||
|
|
@@ -38,40 +33,20 @@ loss, policy_loss, kl = grpo.forward( | |
| loss.backward() # gradient flows into current_logps | ||
| ``` | ||
|
|
||
| `B = num_prompts * samples_per_prompt`. `forward` returns a 3-tuple | ||
| `(loss, policy_loss, kl)`; only `loss` is differentiable (`policy_loss` and `kl` are | ||
| detached reporting scalars). | ||
| Note: `B = num_prompts * samples_per_prompt`. The [B, T] tensors are made contiguous and flattened to 1-D [N = B*T] before the kernel launch. | ||
|
|
||
| ### Group specification | ||
|
|
||
| Provide **exactly one** of: | ||
|
|
||
| - `samples_per_prompt: int` — uniform groups (every prompt has the same number of samples). | ||
| - `group_boundaries` — CSR-style offsets of length `num_groups + 1` (e.g. `[0, 8, 16, 24]`) | ||
| for variable-sized groups. | ||
|
|
||
| ### Lower-level helpers | ||
|
|
||
| ```python | ||
| adv = grpo.group_advantages(rewards, samples_per_prompt=8) # [B] per-sequence advantages | ||
| loss, policy_loss, kl = grpo.apply( # skip reward normalization | ||
| current_logps, old_logps, ref_logps, advantages, completion_mask, | ||
| clip_eps=0.2, beta=0.04, | ||
| ) | ||
| ``` | ||
|
|
||
| !!! note "`apply` advantage layout differs by backend" | ||
| `NativeGRPOLossOp.apply` takes **per-token** advantages `[B, T]` (use | ||
| `expand_advantages`), while `TritonGRPOLossOp.apply` takes **per-sequence** | ||
| advantages `[B]` and gathers per token inside the kernel. Prefer `forward`, whose | ||
| signature is identical across backends. | ||
|
|
||
| ## Backends | ||
|
|
||
| | Backend | Wrapper | Native symbol | Status | | ||
| | --- | --- | --- | --- | | ||
| | CUDA | `TritonGRPOLossOp` | Triton JIT kernels | Fused forward + analytic backward. | | ||
| | ROCm | `TritonGRPOLossOp` | Triton JIT kernels | Same kernels; requires Triton. | | ||
| | PyTorch fallback | `NativeGRPOLossOp` | None | Reference path; CPU and Triton-less GPUs. | | ||
|
|
||
| The Triton op fuses three kernels: `_group_norm_kernel` (per-group reward mean/std in | ||
|
|
@@ -83,21 +58,14 @@ its advantage on the fly (`seq_id = token_index // completion_len`), so the broa | |
|
|
||
| | Argument | Shape | Dtype | Requirements | | ||
| | --- | --- | --- | --- | | ||
| | `current_logps` | `[B, T]` | float (fp32 recommended) | Differentiable input; on the target device. | | ||
| | `current_logps` | `[B, T]` | float (fp32 recommended) | Differentiable input | | ||
| | `old_logps` | `[B, T]` | float | Constant (no grad). | | ||
| | `ref_logps` | `[B, T]` | float | Constant (no grad). | | ||
| | `rewards` | `[B]` | float | One scalar per sequence. | | ||
| | `completion_mask` | `[B, T]` | bool / {0,1} | 2-D; `True` marks active tokens. | | ||
| | `loss` (output) | scalar | float32 | `policy_loss + beta * kl`. | | ||
| | `policy_loss`, `kl` (output) | scalar | float32 | Detached reporting values. | | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix incorrect gradient contract for The docs say these are “detached reporting values,” but the implementation returns them directly (no |
||
|
|
||
| ## Dispatch Behavior | ||
|
|
||
| `kernel_registry.get_op("grpo_loss")` selects `TritonGRPOLossOp` first on CUDA and ROCm, | ||
| and `NativeGRPOLossOp` on CPU or when Triton / the CUDA backend is unavailable. The Triton | ||
| op raises if handed CPU tensors, so on CPU always use the native op (the registry does this | ||
| automatically). | ||
|
|
||
| ## Accuracy | ||
|
|
||
| Reference semantics (matching `examples/grpo_single_gpu.py`): | ||
|
|
@@ -146,18 +114,6 @@ Covers the native reference, Triton forward/backward vs native, the | |
| `logp → grpo_loss` pipeline, masked-token invariance, an SGD loss step, and registry | ||
| dispatch. Triton tests skip without CUDA + Triton. | ||
|
|
||
| ## Known Limitations | ||
|
|
||
| - `TritonGRPOLossOp` requires CUDA tensors and a working Triton install; CPU falls back to | ||
| the native op. | ||
| - `old_logps` / `ref_logps` are precomputed per-token constants — only the current policy | ||
| is differentiated (standard for GRPO). | ||
| - Inputs are dense `[B, T]` (rectangular). Variable-length completions should be padded and | ||
| masked via `completion_mask`. | ||
| - Group reward normalization is a separate lightweight kernel (one scalar per sequence); | ||
| it is intentionally not fused into the token kernel, which would trade away | ||
| token-level parallelism. | ||
|
|
||
| ## Implementation Files | ||
|
|
||
| - `rl_engine/kernels/ops/pytorch/loss/grpo_loss.py` | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a language identifier to the fenced code block.
This fence is missing a language tag (markdownlint MD040). Use something like
```textfor the pipeline diagram.Suggested doc fix
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
@docs/operators/grpo-loss.mdaround lines 12 - 14, The fenced code blockcontaining the pipeline diagram "logits --[logp op]--> logps --[grpo_loss op]-->
loss" is missing a language identifier; update the fence to include a language
tag (e.g., use ```text) so the block follows markdownlint MD040; locate the
fence around that pipeline diagram in docs/operators/grpo-loss.md and add the
language identifier to the opening backticks.