[FEAT][kernels]: implement Fused Policy Ratio and KL Penalty kernel #97

Open
KJLdefeated wants to merge 1 commit into RL-Align:main from KJLdefeated:feat/fused-ratio-kl

Conversation

@KJLdefeated KJLdefeated commented Jun 11, 2026

Contributor

#40

Summary

  • Adds a new fused Policy Ratio + KL Penalty (ratio_kl) operator that turns raw logits into the per-token PPO/GRPO importance ratio and reference-KL penalty via online softmax.
  • Because grpo_loss uses the same ratio and KL computation, this PR refactors grpo_loss ([FEAT][kernels]: implement fused GRPO loss with in-place group reward normalization #93) to consume logits and build on this operator, deleting its bespoke forward/backward Triton kernels (the group-norm kernel is kept). A future PPO loss can also use this kernel.
  • The ratio_kl forward pass has peak VRAM that is independent of vocab size. In the benchmarks below, the fused forward path is 5–10× faster than the native path and replaces 2–8 GB of log-softmax scratch with roughly 0 MB; the advantage grows with vocab size.
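
For readers unfamiliar with the two outputs, here is a minimal per-token sketch in plain Python. It assumes the ratio is `exp(logp_policy - logp_old)` and that the KL penalty uses the k3-style estimator common in GRPO implementations; the PR text does not state which estimator ratio_kl uses, and the function name `ratio_and_kl` is hypothetical.

```python
import math


def ratio_and_kl(logp_policy, logp_ref, logp_old):
    """Per-token importance ratio and an assumed k3-style KL estimate."""
    # Importance ratio between the current policy and the behavior policy.
    ratio = math.exp(logp_policy - logp_old)
    # Assumed estimator: exp(d) - d - 1 with d = logp_ref - logp_policy.
    # It is non-negative and equals zero when policy and reference agree.
    diff = logp_ref - logp_policy
    kl = math.exp(diff) - diff - 1.0
    return ratio, kl


ratio, kl = ratio_and_kl(-1.2, -1.2, -1.2)
```

With identical log-probs the ratio is 1 and the penalty is 0, which matches the "ratio==1 when old==policy" case listed in the tests.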

What changed and why

New operator — ratio_kl

  • rl_engine/kernels/ops/triton/loss/ratio_kl.py:
    • New Triton kernel + _RatioKLFunction autograd + TritonRatioKLOp.
    • Streams an online LSE over the vocab in BLOCK_V tiles for policy and ref
  • rl_engine/kernels/ops/pytorch/loss/ratio_kl.py: NativeRatioKLOp reference, reusing NativeLogpOp
  • tests/test_ratio_kl.py: native-vs-reference, masked neutrality, ratio==1 when old==policy, grad-to-policy-only, Triton fwd/bwd vs native (V=64 and 50257), registry dispatch.
  • benchmarks/benchmark_ratio_kl.py: New benchmark (latency + peak VRAM, vocab sweep).
  • docs/operators/ratio-kl.md: operator doc.
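
The streaming log-sum-exp mentioned above can be sketched in plain Python. This is an illustration of the online-softmax recurrence only, not the Triton kernel itself; `online_logsumexp`, `select_logp`, and `block_v` are hypothetical names, and the sketch assumes every tile contains at least one finite logit.

```python
import math


def online_logsumexp(row, block_v):
    """Log-sum-exp of `row`, visiting it in tiles of `block_v` entries."""
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(row), block_v):
        tile = row[start:start + block_v]
        new_max = max(running_max, max(tile))
        # Rescale the partial sum to the new maximum, then add this tile.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in tile
        )
        running_max = new_max
    return running_max + math.log(running_sum)


def select_logp(row, action, block_v):
    """Log-probability of `action` without building a full log-softmax row."""
    return row[action] - online_logsumexp(row, block_v)
```

Only two scalars per row are carried between tiles, which is why the scratch memory does not grow with the vocab size.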

GRPO loss refactor — compose ratio_kl instead of a bespoke kernel

  • rl_engine/kernels/ops/triton/{triton_grpo_loss.py → loss/grpo_loss.py}:
    • Moved into loss/;
    • Dropped _grpo_fwd/_grpo_bwd kernels and _GRPOLossFunction;
    • Now calls TritonRatioKLOp and keeps only _group_norm_kernel for reward normalization.
    • Results in cleaner code.
  • benchmarks/benchmark_grpo_loss.py: Updated to logits API
  • tests/test_grpo_loss.py: Loss/grad/masking/SGD tests reworked for logits API
  • docs/operators/grpo-loss.md

Wiring

  • rl_engine/kernels/registry.py: Register ratio_kl (Triton/PyTorch) for cuda/rocm/cpu; repoint TRITON_GRPO_LOSS to triton.loss.grpo_loss

Performance (fp16, B=32, T=256)

| shape (P×S×L×V) | fwd speedup | fwd+bwd speedup | peak fwd VRAM (native → Triton) |
|---|---|---|---|
| 4×8×256×32768 | 5.2× | 2.8× | 2048 MB → ~0 MB |
| 4×8×256×50257 | 7.3× | 2.4× | 3141 MB → ~0 MB |
| 4×8×256×131072 | 10.3× | 3.4× | 8192 MB → ~0 MB |
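
As a sanity check on the native peak-VRAM column, the figures are consistent with 8 bytes per logit element, for example two fp32 `[B, T, V]` log-softmax tensors (one for policy, one for ref). That breakdown is an assumption; the PR does not say what the native path allocates. The helper name `dense_scratch_mib` is hypothetical.

```python
def dense_scratch_mib(prompts, samples, length, vocab, bytes_per_elem=8):
    """Size in MiB of dense [B, T, V] scratch at `bytes_per_elem` per logit.

    bytes_per_elem=8 is an assumption (e.g. two fp32 tensors), chosen because
    it reproduces the reported numbers.
    """
    elements = prompts * samples * length * vocab
    return elements * bytes_per_elem / 2**20


for vocab in (32768, 50257, 131072):
    print(vocab, round(dense_scratch_mib(4, 8, 256, vocab)))
```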

Notes

  • A PPO wrapper will be implemented in a future PR.
  • The current CUDA Logp kernel lacks a backward function.

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced new ratio_kl operator for fused policy ratio and KL penalty computation in PPO/GRPO workflows.
    • Refactored GRPO Loss operator to consume logits directly, improving performance through fused computation instead of pre-computed log-probabilities.
  • Documentation

    • Updated GRPO Loss documentation with new API signature and implementation details.
    • Added new documentation page for ratio_kl operator with tensor contract and performance specifications.
  • Tests

    • Added comprehensive test suite for ratio_kl operator covering correctness, masking, and gradient flow.
    • Updated GRPO Loss tests to align with logits-based inputs and new API.
  • Benchmarking

    • Added benchmark scripts for performance measurement of loss operators across backends.

coderabbitai Bot commented Jun 11, 2026


📝 Walkthrough

This PR introduces a fused ratio_kl operator (policy ratio + KL penalty) with Triton and PyTorch implementations, refactors GRPO loss to use it instead of per-token log-probabilities, updates benchmarks with vocab-aware configs, adds comprehensive test coverage, and documents the new logits-based pipeline.

Changes

Fused Ratio-KL and GRPO Loss Integration

| Layer | File(s) | Summary |
|---|---|---|
| Triton Ratio-KL Kernels and Autograd | rl_engine/kernels/ops/triton/loss/ratio_kl.py | Forward/backward Triton kernels compute log-sum-exp softmax, derive selected-action log-probabilities, compute ratio = exp(delta) and KL penalty from policy/ref logits via autograd wrapper; stores intermediate d and logz for backward pass. |
| PyTorch Fallback Ratio-KL and Tests | rl_engine/kernels/ops/pytorch/loss/ratio_kl.py, tests/test_ratio_kl.py | Native fallback uses log-prob selection with clamped indices; test suite validates numerical agreement, masked-token neutrality, gradient flow to policy logits, Triton/native equivalence, and registry dispatch. |
| Native GRPO Loss Refactoring | rl_engine/kernels/ops/pytorch/loss/grpo_loss.py | NativeGRPOLossOp now instantiates NativeRatioKLOp and accepts policy_logits/ref_logits/action_ids instead of per-token log-probs; derives advantages from rewards and applies clipped surrogate with masked means. |
| Triton GRPO Loss Refactoring | rl_engine/kernels/ops/triton/loss/grpo_loss.py | TritonGRPOLossOp integrates TritonRatioKLOp, replaces custom autograd kernels with PyTorch-based loss computation, adds expand_advantages and _masked_mean helpers, maintains group-normalized advantage normalization. |
| Kernel Registry and Backend Dispatch | rl_engine/kernels/registry.py, rl_engine/kernels/ops/triton/loss/__init__.py | Registers TRITON_RATIO_KL and PYTORCH_RATIO_KL backends in OpBackend enum; extends _priority_map for cuda/rocm/cpu to route ratio_kl operator; updates TRITON_GRPO_LOSS module path. |
| GRPO Loss Tests | tests/test_grpo_loss.py | Test suite refactored for logits-based inputs: validates forward/backward matching between native and Triton, gradient flow to policy logits only, masked-token invariance, group-advantages, advantage application (apply), and SGD-style loss descent. |
| Benchmarks | benchmarks/benchmark_grpo_loss.py, benchmarks/benchmark_ratio_kl.py | GRPO benchmark updated to 4-tuple configs (prompts, samples, len, vocab) with logits inputs; new ratio_kl benchmark measures performance and drift between reference and candidate with configurable shapes, timing, and CSV output. |
| Operator Documentation | docs/operators/ratio-kl.md, docs/operators/grpo-loss.md, docs/.nav.yml, docs/operators/README.md | New ratio_kl docs specify tensor contract, reference semantics, Triton/PyTorch backends, and performance results; GRPO docs updated to reflect logits-based interface and ratio_kl composition; navigation index updated. |

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant NativeGRPOLossOp
  participant NativeRatioKLOp
  participant Advantage as Advantage Normalization
  participant Loss as Loss Computation
  User->>NativeGRPOLossOp: forward(policy_logits, ref_logits, action_ids, old_logps, rewards)
  NativeGRPOLossOp->>Advantage: group_advantages(rewards, boundaries)
  Advantage-->>NativeGRPOLossOp: sample_advantages
  NativeGRPOLossOp->>NativeRatioKLOp: __call__(policy_logits, ref_logits, action_ids, ...)
  NativeRatioKLOp->>NativeRatioKLOp: select log-probs, compute delta/diff
  NativeRatioKLOp-->>NativeGRPOLossOp: ratio, kl_terms
  NativeGRPOLossOp->>Loss: expand_advantages, compute clipped surrogate
  Loss-->>NativeGRPOLossOp: loss, policy_loss, kl
  NativeGRPOLossOp-->>User: (loss, policy_loss, kl), backward graph
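
The flow in the diagram (group advantages, then ratio/KL, then a clipped surrogate reduced with a masked mean) can be sketched in plain Python. This is a hedged illustration, not the operator's code: the function names, the population standard deviation, and the defaults `eps=1e-6`, `clip_eps=0.2`, and `beta=0.04` are all assumptions.

```python
import math


def group_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one group to zero mean and unit scale."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    return [(r - mean) / (math.sqrt(var) + eps) for r in rewards]


def grpo_token_loss(ratio, kl, advantage, clip_eps=0.2, beta=0.04):
    """Clipped surrogate plus beta * kl for a single token."""
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surrogate = min(ratio * advantage, clipped * advantage)
    return -surrogate + beta * kl


def masked_mean(values, mask):
    """Average over active tokens only; masked tokens contribute nothing."""
    active = sum(mask)
    return sum(v * m for v, m in zip(values, mask)) / max(active, 1)
```

Because masked positions are multiplied by zero before the reduction, whatever value the masked tokens hold cannot change the loss, which is the property the masked-token tests check.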

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

  • RL-Align/RL-Kernel#40: This PR implements the exact fused policy-ratio + KL kernel (Triton + PyTorch fallback) and integrates it into GRPO loss, directly addressing the feature specification and function-level design.

Possibly related PRs

  • RL-Align/RL-Kernel#93: Both PRs modify GRPO loss operator signatures, benchmarks, registry dispatch, and test structure; PR #97 builds on the fused GRPO/group-normalization work from #93.

Suggested reviewers

  • inaniloquentee
  • Flink-ddd

Poem

🐰 A kernel so fused, logits refined,
Policy ratios and KL aligned!
Triton computes while PyTorch plays,
GRPO loss brightens all our days. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 17.72% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title '[FEAT][kernels]: implement Fused Policy Ratio and KL Penalty kernel' is a direct, specific description of the main feature introduced by this PR. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
rl_engine/kernels/ops/pytorch/loss/grpo_loss.py (1)

192-203: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Mirror Triton’s group_boundaries validation here.

_resolve_group_ids() currently accepts translated boundary vectors because it only checks sizes. For example, a 4-sequence batch would accept [2, 4, 6] and silently treat it like [0, 2, 4], while TritonGRPOLossOp._build_bounds() rejects the same input because it requires batch-local offsets that start at 0 and end at num_sequences. That makes the shared API behave differently across backends depending on dispatch target.

Suggested fix
         boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long)
         if boundaries.ndim != 1 or boundaries.numel() < 2:
             raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.")
+        if int(boundaries[0].item()) != 0 or int(boundaries[-1].item()) != num_sequences:
+            raise ValueError("group_boundaries must start at 0 and end at num_sequences.")
         sizes = boundaries[1:] - boundaries[:-1]
-
-        if int(sizes.sum().item()) != num_sequences:
-            raise ValueError(
-                f"group sizes sum to {int(sizes.sum().item())} but there are "
-                f"{num_sequences} sequences."
-            )
         if bool((sizes < 1).any().item()):
             raise ValueError("each group must contain at least one sequence.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py` around lines 192 - 203,
update _resolve_group_ids() to mirror TritonGRPOLossOp._build_bounds()
validation: ensure group_boundaries is a 1D integer tensor of length
num_groups+1 on the correct device, that boundaries[0] == 0 and boundaries[-1]
== num_sequences, and that the sequence is non-decreasing (each boundary >=
previous) and all group sizes are >=1; raise a ValueError with a clear message
if any of these checks fail so the PyTorch path rejects translated/offset
vectors like [2,4,6] just like Triton does.
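
A plain-Python sketch of the stricter validation described above, using lists instead of tensors so it runs without PyTorch. The function name `validate_group_boundaries` is hypothetical; the checks follow the suggested fix.

```python
def validate_group_boundaries(boundaries, num_sequences):
    """Return group sizes, or raise ValueError for malformed boundaries."""
    if len(boundaries) < 2:
        raise ValueError("group_boundaries must have length num_groups + 1.")
    # Offsets are batch-local: they must start at 0 and end at num_sequences.
    if boundaries[0] != 0 or boundaries[-1] != num_sequences:
        raise ValueError("group_boundaries must start at 0 and end at num_sequences.")
    sizes = [hi - lo for lo, hi in zip(boundaries, boundaries[1:])]
    # A size >= 1 for every group also implies the offsets strictly increase.
    if any(size < 1 for size in sizes):
        raise ValueError("each group must contain at least one sequence.")
    return sizes
```

With this check, `[0, 2, 4]` is accepted for a 4-sequence batch, while the translated vector `[2, 4, 6]` is rejected even though its group sizes also sum to 4.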
🧹 Nitpick comments (2)
docs/operators/ratio-kl.md (1)

9-11: 💤 Low value

Consider adding a language identifier to the fenced code block.

The ASCII pipeline diagram lacks a language specifier. Adding text improves rendering and addresses the linter warning.

📝 Proposed fix
-```
+```text
 logits --[ratio_kl op]--> (ratio, kl) --[clipped surrogate + beta*kl]--> loss
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/ratio-kl.md around lines 9 - 11, The fenced ASCII pipeline
block containing "logits --[ratio_kl op]--> (ratio, kl) --[clipped surrogate +
beta*kl]--> loss" should include a language identifier (e.g., text) to satisfy
the linter and improve rendering; update the triple backtick fence that
surrounds the pipeline diagram to use ```text so the block becomes a text code
block while keeping the pipeline content unchanged.


</details>

<!-- cr-comment:v1:ade0dc832571fe86b6ad1795 -->

_Source: Linters/SAST tools_

</blockquote></details>
<details>
<summary>docs/operators/grpo-loss.md (1)</summary><blockquote>

`15-15`: _💤 Low value_

**Consider adding a language identifier to the fenced code block.**

The ASCII pipeline diagram uses a fenced code block without a language specifier. While this is valid Markdown, adding `text` improves rendering consistency and silences the linter warning.




<details>
<summary>📝 Proposed fix</summary>

```diff
-```
+```text
 logits --[ratio_kl op]--> (ratio, kl) --[group adv + clipped surrogate]--> loss
 ```
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/grpo-loss.md at line 15, The fenced ASCII pipeline diagram
lacks a language tag; update the fenced code block that contains "logits
--[ratio_kl op]--> (ratio, kl) --[group adv + clipped surrogate]--> loss" to
include a language identifier (use "text") after the opening backticks, so the
fence opens with ```text, to satisfy the linter and improve rendering
consistency.


</details>

<!-- cr-comment:v1:58983bcca01d03c1472ca2d1 -->

_Source: Linters/SAST tools_

</blockquote></details>

</blockquote></details>

<details>
<summary>🤖 Prompt for all review comments with AI agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In @benchmarks/benchmark_grpo_loss.py:

  • Around line 93-109: The nested functions native_fwd, triton_fwd,
    native_fwd_bwd, and triton_fwd_bwd capture loop-scoped variables call_args and
    spp causing potential late-binding (Ruff B023); fix by binding those loop
    variables as default arguments in each function signature (e.g.,
    native_fwd(p=policy, call_args=call_args, spp=spp) and similarly triton_fwd, and
    for the backward helpers use p=pol_grad, call_args=call_args, spp=spp) so each
    closure retains the current iteration's values instead of the final ones.

In @docs/operators/grpo-loss.md:

  • Line 30: Doc parameter names are inconsistent: grpo_loss docs use
    completion_mask while ratio_kl docs and implementations use attention_mask;
    update the docs so the same tensor name is used and unambiguous. Pick one
    canonical name (e.g., attention_mask) and rename the parameter in
    docs/operators/grpo-loss.md (and any related doc examples) from completion_mask
    to attention_mask, or add a single clarifying sentence in grpo_loss docs stating
    that completion_mask is the same tensor passed as attention_mask to ratio_kl;
    ensure references to shapes ([B, T]) and callsites in grpo_loss, ratio_kl, and
    fused _ratio_kl are updated to match the chosen terminology.

In @docs/operators/ratio-kl.md:

  • Around line 43-48: Add a short introductory sentence before the formula to
    explain what it represents (for example: "The Triton backward kernel
    computes:"), add a language identifier to the fenced code block (e.g., text or python) around the line grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v)), and either move or expand the following sentence ("so the
    backward also avoids materializing...") so it immediately relates to the formula
    (e.g., state that this expression means no [B, T, V] probability tensor is
    materialized), ensuring the formula is connected to the preceding discussion of
    backends and wrapper classes.

In @rl_engine/kernels/ops/triton/loss/ratio_kl.py:

  • Around line 38-67: The kernel reads an unclamped action id "a" (loaded as a =
    tl.load(action_ptr + row)) and then uses it to index policy_ptr/ref_ptr, which
    can cause out-of-bounds reads for active tokens; clamp "a" to the valid range
    [0, V-1] before any indexing (e.g., replace uses of "a" with a_safe =
    tl.minimum(tl.maximum(a, 0), V - 1) or tl.clamp if available) and ensure a_safe
    has the correct integer dtype matching row_off offsets; update the loads
    tl.load(policy_ptr + row_off + a) and tl.load(ref_ptr + row_off + a) to use
    a_safe so active-path indexing is always safe in the ratio_kl Triton forward
    kernel.

Outside diff comments:
In @rl_engine/kernels/ops/pytorch/loss/grpo_loss.py:

  • Around line 192-203: update _resolve_group_ids() to mirror
    TritonGRPOLossOp._build_bounds() validation: ensure group_boundaries is a 1D
    integer tensor of length num_groups+1 on the correct device, that boundaries[0]
    == 0 and boundaries[-1] == num_sequences, and that the sequence is
    non-decreasing (each boundary >= previous) and all group sizes are >=1; raise a
    ValueError with a clear message if any of these checks fail so the PyTorch path
    rejects translated/offset vectors like [2,4,6] just like Triton does.

Nitpick comments:
In @docs/operators/grpo-loss.md:

  • Line 15: The fenced ASCII pipeline diagram lacks a language tag; update the
    fenced code block that contains "logits --[ratio_kl op]--> (ratio, kl) --[group
    adv + clipped surrogate]--> loss" to include a language identifier (use "text")
    after the opening backticks, so the fence opens with ```text, to satisfy the
    linter and improve rendering consistency.

In @docs/operators/ratio-kl.md:

  • Around line 9-11: The fenced ASCII pipeline block containing "logits
    --[ratio_kl op]--> (ratio, kl) --[clipped surrogate + beta*kl]--> loss" should
    include a language identifier (e.g., text) to satisfy the linter and improve
    rendering; update the triple backtick fence that surrounds the pipeline diagram
    to use ```text so the block becomes a text code block while keeping the pipeline
    content unchanged.

</details>


---

<details>
<summary>ℹ️ Review info</summary>

<details>
<summary>⚙️ Run configuration</summary>

**Configuration used**: defaults

**Review profile**: CHILL

**Plan**: Pro

**Run ID**: `b91d50a6-caba-4711-b397-1074cc35634b`

</details>

<details>
<summary>📥 Commits</summary>

Reviewing files that changed from the base of the PR and between 4a9ca42a07c9da6b5634375e1043f9150b19a9e7 and 6a858ae077e3302f6c0456911a289f4cb11481e5.

</details>

<details>
<summary>📒 Files selected for processing (14)</summary>

* `benchmarks/benchmark_grpo_loss.py`
* `benchmarks/benchmark_ratio_kl.py`
* `docs/.nav.yml`
* `docs/operators/README.md`
* `docs/operators/grpo-loss.md`
* `docs/operators/ratio-kl.md`
* `rl_engine/kernels/ops/pytorch/loss/grpo_loss.py`
* `rl_engine/kernels/ops/pytorch/loss/ratio_kl.py`
* `rl_engine/kernels/ops/triton/loss/__init__.py`
* `rl_engine/kernels/ops/triton/loss/grpo_loss.py`
* `rl_engine/kernels/ops/triton/loss/ratio_kl.py`
* `rl_engine/kernels/registry.py`
* `tests/test_grpo_loss.py`
* `tests/test_ratio_kl.py`

</details>

</details>

<!-- This is an auto-generated comment by CodeRabbit for review status -->

Comment on lines +93 to +109
+        def native_fwd(p=policy):
             with torch.no_grad():
-                native.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
+                native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)

-        def triton_fwd(c=current):
+        def triton_fwd(p=policy):
             with torch.no_grad():
-                triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
+                triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)

-        cur_grad = current.clone().requires_grad_(True)
+        pol_grad = policy.clone().requires_grad_(True)

-        def native_fwd_bwd(c=cur_grad):
-            loss, _, _ = native.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
-            torch.autograd.grad(loss, c)
+        def native_fwd_bwd(p=pol_grad):
+            loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+            torch.autograd.grad(loss, p)

-        def triton_fwd_bwd(c=cur_grad):
-            loss, _, _ = triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
-            torch.autograd.grad(loss, c)
+        def triton_fwd_bwd(p=pol_grad):
+            loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+            torch.autograd.grad(loss, p)


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Bind loop variables as default arguments to avoid late binding.

The nested functions capture call_args and spp from the loop scope without binding them as default arguments. While the functions are called immediately within the same iteration (making this safe in practice), it's still a code smell flagged by Ruff (B023). If the benchmark logic changes and these functions are stored or called later, they would reference the last loop iteration's values.

🛡️ Recommended fix
-        def native_fwd(p=policy):
+        def native_fwd(p=policy, args=call_args, s=spp):
             with torch.no_grad():
-                native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+                native.forward(p, *args, samples_per_prompt=s, **kwargs)

-        def triton_fwd(p=policy):
+        def triton_fwd(p=policy, args=call_args, s=spp):
             with torch.no_grad():
-                triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+                triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)

         pol_grad = policy.clone().requires_grad_(True)

-        def native_fwd_bwd(p=pol_grad):
-            loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+        def native_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+            loss, _, _ = native.forward(p, *args, samples_per_prompt=s, **kwargs)
             torch.autograd.grad(loss, p)

-        def triton_fwd_bwd(p=pol_grad):
-            loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+        def triton_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+            loss, _, _ = triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)
             torch.autograd.grad(loss, p)
🧰 Tools
🪛 Ruff (0.15.15)

[warning] 95-95: Function definition does not bind loop variable call_args

(B023)


[warning] 95-95: Function definition does not bind loop variable spp

(B023)


[warning] 99-99: Function definition does not bind loop variable call_args

(B023)


[warning] 99-99: Function definition does not bind loop variable spp

(B023)


[warning] 104-104: Function definition does not bind loop variable call_args

(B023)


[warning] 104-104: Function definition does not bind loop variable spp

(B023)


[warning] 108-108: Function definition does not bind loop variable call_args

(B023)


[warning] 108-108: Function definition does not bind loop variable spp

(B023)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_grpo_loss.py` around lines 93 - 109, The nested
functions native_fwd, triton_fwd, native_fwd_bwd, and triton_fwd_bwd capture
loop-scoped variables call_args and spp causing potential late-binding (Ruff
B023); fix by binding those loop variables as default arguments in each function
signature (e.g., native_fwd(p=policy, call_args=call_args, spp=spp) and
similarly triton_fwd, and for the backward helpers use p=pol_grad,
call_args=call_args, spp=spp) so each closure retains the current iteration's
values instead of the final ones.
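
The late-binding behavior behind Ruff B023 can be reproduced in a few lines of plain Python. This is a standalone illustration, not code from the benchmark.

```python
late = []
bound = []
for i in range(3):
    def read_late():
        # Looks up `i` when called, so it sees whatever `i` is at that time.
        return i

    def read_bound(i=i):
        # The default argument captures the value of `i` at definition time.
        return i

    late.append(read_late)
    bound.append(read_bound)

late_values = [f() for f in late]    # [2, 2, 2]
bound_values = [f() for f in bound]  # [0, 1, 2]
```

Calling the functions inside the same iteration, as the benchmark does, hides the difference; it only shows up once the callables outlive the loop body.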

action_ids, # [B, T] token taken at each position
old_logps, # [B, T] cached behavior-policy log-probs
rewards, # [B]
completion_mask, # [B, T]


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check mask parameter naming in grpo_loss and ratio_kl implementations

# Check grpo_loss operator signatures
echo "=== grpo_loss signatures ==="
ast-grep --pattern $'def forward($$$, completion_mask, $$$)'
ast-grep --pattern $'def forward($$$, attention_mask, $$$)'

# Check ratio_kl operator signatures  
echo "=== ratio_kl signatures ==="
rg -A 3 'def forward' rl_engine/kernels/ops/pytorch/loss/ratio_kl.py
rg -A 3 'def forward' rl_engine/kernels/ops/triton/loss/ratio_kl.py

Repository: RL-Align/RL-Kernel

Length of output: 3188


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Find mask parameter naming in docs =="
sed -n '1,120p' docs/operators/grpo-loss.md | nl -ba | sed -n '20,60p'
echo
sed -n '1,120p' docs/operators/ratio-kl.md | nl -ba | sed -n '1,80p'
echo
echo "== Search code for completion_mask and grpo_loss =="
rg -n --hidden --glob '!**/node_modules/**' --glob '!**/.git/**' "completion_mask" .
rg -n --hidden --glob '!**/node_modules/**' --glob '!**/.git/**' "grpo_loss" rl_engine docs .
echo
echo "== Search ratio_kl signatures for attention_mask/completion_mask =="
rg -n "def forward\(" rl_engine/kernels/ops -S --glob '*ratio_kl*'
rg -n "attention_mask" rl_engine/kernels/ops -S --glob '*ratio_kl*'
rg -n "completion_mask" rl_engine/kernels/ops -S --glob '*ratio_kl*'

Repository: RL-Align/RL-Kernel

Length of output: 164


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Docs: grpo-loss mask naming =="
cat -n docs/operators/grpo-loss.md | sed -n '1,120p'

echo
echo "== Docs: ratio-kl mask naming =="
cat -n docs/operators/ratio-kl.md | sed -n '1,120p'

echo
echo "== Code search: completion_mask and grpo_loss =="
rg -n "completion_mask" rl_engine docs || true
rg -n "grpo_loss" rl_engine docs || true

echo
echo "== Code search: ratio_kl call sites =="
rg -n "ratio_kl" rl_engine/kernels/ops -S || true

echo
echo "== Code search: attention_mask/completion_mask in ratio_kl =="
rg -n "attention_mask" rl_engine/kernels/ops -S --glob '*ratio*kl*' || true
rg -n "completion_mask" rl_engine/kernels/ops -S --glob '*ratio*kl*' || true

echo
echo "== Find grpo_loss operator implementation files =="
fd -e py -e md -e ts -e js grpo -a rl_engine docs || true
fd -e py -e md grpo_loss -a rl_engine docs || true

Repository: RL-Align/RL-Kernel

Length of output: 19926


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== grpo_loss: pytorch call to ratio_kl =="
cat -n rl_engine/kernels/ops/pytorch/loss/grpo_loss.py | sed -n '90,140p'

echo
echo "== grpo_loss: triton call to ratio_kl =="
cat -n rl_engine/kernels/ops/triton/loss/grpo_loss.py | sed -n '180,240p'

echo
echo "== ratio_kl: pytorch forward signature =="
cat -n rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | sed -n '1,80p'

echo
echo "== ratio_kl: triton forward signature (ctx) =="
cat -n rl_engine/kernels/ops/triton/loss/ratio_kl.py | sed -n '110,220p'

Repository: RL-Align/RL-Kernel

Length of output: 11128


Align mask parameter naming across grpo_loss and ratio_kl docs/contracts.

  • docs/operators/grpo-loss.md documents the mask argument as completion_mask ([B, T]), and grpo_loss implementations take completion_mask and pass it positionally into the fused _ratio_kl call.
  • docs/operators/ratio-kl.md (and the ratio_kl implementations) name that same mask argument as attention_mask ([B, T] active tokens).

Rename one of the doc parameters for consistency (or explicitly document that completion_mask in grpo_loss is the same tensor as attention_mask consumed by ratio_kl), to avoid confusion when users call these ops directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/grpo-loss.md` at line 30, Doc parameter names are
inconsistent: grpo_loss docs use completion_mask while ratio_kl docs and
implementations use attention_mask; update the docs so the same tensor name is
used and unambiguous. Pick one canonical name (e.g., attention_mask) and rename
the parameter in docs/operators/grpo-loss.md (and any related doc examples) from
completion_mask to attention_mask, or add a single clarifying sentence in
grpo_loss docs stating that completion_mask is the same tensor passed as
attention_mask to ratio_kl; ensure references to shapes ([B, T]) and callsites
in grpo_loss, ratio_kl, and fused _ratio_kl are updated to match the chosen
terminology.

Comment on lines +43 to +48
```
grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v))
```

so the backward also avoids materializing any `[B, T, V]` probability tensor (only the
unavoidable `[B, T, V]` gradient output is written).


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Incomplete or unclear context for the gradient formula.

The fenced code block at lines 43-45 shows a gradient formula but appears disconnected from surrounding text. The preceding paragraph (lines 36-42) discusses backends and wrapper classes, then the formula appears without introduction. Additionally, the formula block lacks a language identifier.

📋 Suggested improvements
  1. Add introductory text before the formula explaining what it represents (e.g., "The Triton backward kernel computes:").
  2. Add a language identifier to the code block (python or text).
  3. Consider whether the text at lines 47-48 ("so the backward also avoids...") should be moved before the formula or expanded to provide better context.

Example fix:

 | PyTorch fallback | `NativeRatioKLOp` | None | Reference path; CPU and Triton-less GPUs. |
 
+The Triton backward kernel computes the policy logits gradient analytically:
+
-```
+```python
 grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v))
</details>

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 markdownlint-cli2 (0.22.1)</summary>

[warning] 43-43: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @docs/operators/ratio-kl.md around lines 43 - 48, Add a short introductory
sentence before the formula to explain what it represents (for example: "The
Triton backward kernel computes:"), add a language identifier to the fenced code
block (e.g., text or python) around the line grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v)), and either move or expand the following
sentence ("so the backward also avoids materializing...") so it immediately
relates to the formula (e.g., state that this expression means no [B, T, V]
probability tensor is materialized), ensuring the formula is connected to the
preceding discussion of backends and wrapper classes.


</details>

_Source: Linters/SAST tools_


Comment on lines +38 to +67
if active:
    row_off = row.to(tl.int64) * V
    a = tl.load(action_ptr + row)

    max_p = -float("inf")
    sum_p = 0.0
    max_r = -float("inf")
    sum_r = 0.0
    for start in range(0, V, BLOCK_V):
        cols = start + tl.arange(0, BLOCK_V)
        cmask = cols < V
        p = tl.load(policy_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32)
        r = tl.load(ref_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32)

        tmax_p = tl.max(p, axis=0)
        nmax_p = tl.maximum(max_p, tmax_p)
        sum_p = sum_p * tl.exp(max_p - nmax_p) + tl.sum(tl.exp(p - nmax_p), axis=0)
        max_p = nmax_p

        tmax_r = tl.max(r, axis=0)
        nmax_r = tl.maximum(max_r, tmax_r)
        sum_r = sum_r * tl.exp(max_r - nmax_r) + tl.sum(tl.exp(r - nmax_r), axis=0)
        max_r = nmax_r

    logz_p = max_p + tl.log(sum_p)
    logz_r = max_r + tl.log(sum_r)
    pa = tl.load(policy_ptr + row_off + a).to(tl.float32)
    ra = tl.load(ref_ptr + row_off + a).to(tl.float32)
    logp_p = pa - logz_p
    logp_r = ra - logz_r
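
The tiled loop in the snippet above is an online log-sum-exp: it keeps a running maximum and a running sum that is rescaled whenever the maximum grows. A minimal plain-Python sketch of the same recurrence for a single row follows, assuming finite logits; `online_logsumexp`, `direct_logsumexp`, and `block_v` are illustrative names, not the PR's API.

```python
import math


def direct_logsumexp(logits):
    """Reference log-sum-exp that looks at the whole row at once."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def online_logsumexp(logits, block_v):
    """Streaming log-sum-exp over tiles of size block_v (assumes finite logits)."""
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(logits), block_v):
        tile = logits[start:start + block_v]
        new_max = max(running_max, max(tile))
        # Rescale the old sum to the new maximum, then add this tile's terms.
        # On the first tile exp(-inf) is 0.0, so the empty running sum stays 0.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + sum(math.exp(x - new_max) for x in tile))
        running_max = new_max
    return running_max + math.log(running_sum)
```

The result should match `direct_logsumexp` for any tile size, which is why the kernel's scratch memory does not need to scale with the vocabulary size.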

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if action_ids can be out-of-bounds at call sites

# Search for ratio_kl operator calls and examine action_ids sources
rg -nP -C5 '(TritonRatioKLOp|NativeRatioKLOp|ratio_kl.*\(|\.get_op\(["\']ratio_kl)' \
  --type=py \
  -g '!test*.py'

# Look for any validation or clamping of action_ids/token_ids before operator calls
rg -nP -C3 '(action_ids|token_ids).*(clamp|clip|clip_|where|masked_fill)' \
  --type=py \
  -g '!test*.py'

Repository: RL-Align/RL-Kernel

Length of output: 235


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect Triton and PyTorch implementations around the referenced lines.
echo "== Triton file (ratio_kl.py) ==" 
sed -n '1,220p' rl_engine/kernels/ops/triton/loss/ratio_kl.py | nl -ba | sed -n '1,220p'

echo "== PyTorch file (ratio_kl.py) ==" 
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | nl -ba | sed -n '1,220p'


# Find call sites/usages of ratio_kl / ratio kl ops.
echo "== Search for ratio_kl call sites ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  -S '(ratio_kl\.|ratio_kl\b|TritonRatioKLOp|NativeRatioKLOp)' \
  -g '*.py'

# Search for any clamping/validation of action_ids / token_ids before operator invocation.
echo "== Search for clamps on action_ids/token_ids ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  -S 'action_ids|token_ids' \
  -g '*.py' | head -n 200

echo "== Search for clamp/clip usage near action_ids/token_ids ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  -S '(action_ids|token_ids).{0,80}(clamp|clip|clip_|\bwhere\b|masked_fill)' \
  -g '*.py'

Repository: RL-Align/RL-Kernel

Length of output: 155


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Triton file (ratio_kl.py) ==" 
cat -n rl_engine/kernels/ops/triton/loss/ratio_kl.py | sed -n '1,220p'

echo "== PyTorch file (ratio_kl.py) ==" 
cat -n rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | sed -n '1,220p'

echo "== Search for ratio_kl call sites ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  'ratio_kl' \
  rl_engine -S

echo "== Search for clamps/validation of action_ids/token_ids ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  '(action_ids|token_ids).*\.?(clamp|clip|where|masked_fill)\b|safe_ids|safe_.*ids' \
  rl_engine -S

echo "== Search for range checks involving action_ids/token_ids ==" 
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
  '(action_ids|token_ids).*(<|<=|>|>=).*(V|vocab|size|logits)' \
  rl_engine -S || true

Repository: RL-Align/RL-Kernel

Length of output: 12176


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Triton GRPO loss: around _ratio_kl invocation =="
rg -n "ratio, kl_terms = self\\._ratio_kl|self\\._ratio_kl\\(" \
  rl_engine/kernels/ops/triton/loss/grpo_loss.py
sed -n '160,260p' rl_engine/kernels/ops/triton/loss/grpo_loss.py

echo "== Triton GRPO loss: earlier preparation of action_ids / attention_mask =="
rg -n "action_ids|attention_mask" rl_engine/kernels/ops/triton/loss/grpo_loss.py
sed -n '1,160p' rl_engine/kernels/ops/triton/loss/grpo_loss.py

echo "== PyTorch GRPO loss: around _ratio_kl invocation =="
rg -n "ratio, kl_terms = self\\._ratio_kl|self\\._ratio_kl\\(" \
  rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
sed -n '80,160p' rl_engine/kernels/ops/pytorch/loss/grpo_loss.py

echo "== PyTorch GRPO loss: earlier preparation of action_ids / attention_mask =="
rg -n "action_ids|attention_mask" rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
sed -n '1,90p' rl_engine/kernels/ops/pytorch/loss/grpo_loss.py

echo "== reference_ops: gather_token_ids masked_fill context =="
sed -n '1,120p' rl_engine/testing/reference_ops.py
rg -n "gather_token_ids" rl_engine/testing/reference_ops.py

Repository: RL-Align/RL-Kernel

Length of output: 20105


Critical: Triton ratio_kl uses unclamped action_ids for active tokens

  • rl_engine/kernels/ops/triton/loss/ratio_kl.py loads a = tl.load(action_ptr + row) and indexes logits with tl.load(policy_ptr + row_off + a) / tl.load(ref_ptr + row_off + a) without clamping/validation.
  • Although the kernel only reads a when active = attention_mask_flat[row] != 0, any out-of-range action_ids at active positions (a < 0 or a >= V) can produce out-of-bounds reads and incorrect losses/grads.
  • PyTorch fallback rl_engine/kernels/ops/pytorch/loss/ratio_kl.py clamps action_ids via safe_ids = action_ids.clamp(0, logits.size(-1) - 1).long(), so behavior is inconsistent when upstream supplies invalid active ids.
  • No other range checks/clamping for action_ids were found in the repo call paths (only this PyTorch clamp).
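
The mismatch described in the bullets above comes down to whether an id is clamped before it is used as an index. The following is a small plain-Python sketch of the clamping behaviour attributed to the PyTorch fallback, applied to one row; `clamp_action_id` and `gather_logit` are hypothetical helpers used only for illustration.

```python
def clamp_action_id(action_id, vocab_size):
    """Clamp an id into the valid vocab range [0, vocab_size - 1]."""
    return max(0, min(action_id, vocab_size - 1))


def gather_logit(row_logits, action_id):
    """Read the logit for action_id after clamping, so the index is always valid."""
    safe_id = clamp_action_id(action_id, len(row_logits))
    return row_logits[safe_id]
```

Clamping keeps the read in bounds but does not make an invalid id meaningful: the returned logit belongs to a different token, so validating ids upstream is still worth considering alongside the clamp.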
🛡️ Recommended fix: Clamp `a` in the Triton forward kernel
     if active:
         row_off = row.to(tl.int64) * V
         a = tl.load(action_ptr + row)
+        # Clamp action_id to valid vocab range [0, V)
+        a = tl.maximum(0, tl.minimum(a, V - 1))
 
         max_p = -float("inf")
         sum_p = 0.0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 if active:
     row_off = row.to(tl.int64) * V
     a = tl.load(action_ptr + row)
+    # Clamp action_id to valid vocab range [0, V)
+    a = tl.maximum(0, tl.minimum(a, V - 1))
     max_p = -float("inf")
     sum_p = 0.0
     max_r = -float("inf")
     sum_r = 0.0
     for start in range(0, V, BLOCK_V):
         cols = start + tl.arange(0, BLOCK_V)
         cmask = cols < V
         p = tl.load(policy_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32)
         r = tl.load(ref_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32)
         tmax_p = tl.max(p, axis=0)
         nmax_p = tl.maximum(max_p, tmax_p)
         sum_p = sum_p * tl.exp(max_p - nmax_p) + tl.sum(tl.exp(p - nmax_p), axis=0)
         max_p = nmax_p
         tmax_r = tl.max(r, axis=0)
         nmax_r = tl.maximum(max_r, tmax_r)
         sum_r = sum_r * tl.exp(max_r - nmax_r) + tl.sum(tl.exp(r - nmax_r), axis=0)
         max_r = nmax_r
     logz_p = max_p + tl.log(sum_p)
     logz_r = max_r + tl.log(sum_r)
     pa = tl.load(policy_ptr + row_off + a).to(tl.float32)
     ra = tl.load(ref_ptr + row_off + a).to(tl.float32)
     logp_p = pa - logz_p
     logp_r = ra - logz_r
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/loss/ratio_kl.py` around lines 38 - 67, The
kernel reads an unclamped action id "a" (loaded as a = tl.load(action_ptr +
row)) and then uses it to index policy_ptr/ref_ptr, which can cause
out-of-bounds reads for active tokens; clamp "a" to the valid range [0, V-1]
before any indexing (e.g., replace uses of "a" with a_safe =
tl.minimum(tl.maximum(a, 0), V - 1) or tl.clamp if available) and ensure a_safe
has the correct integer dtype matching row_off offsets; update the loads
tl.load(policy_ptr + row_off + a) and tl.load(ref_ptr + row_off + a) to use
a_safe so active-path indexing is always safe in the ratio_kl Triton forward
kernel.
