[FEAT][kernels]: implement Fused Policy Ratio and KL Penalty kernel#97
[FEAT][kernels]: implement Fused Policy Ratio and KL Penalty kernel#97KJLdefeated wants to merge 1 commit into
Conversation
…Triton/CUDA/ROCm)
📝 WalkthroughWalkthroughThis PR introduces a fused ChangesFused Ratio-KL and GRPO Loss Integration
Sequence Diagram(s)sequenceDiagram
participant User
participant NativeGRPOLossOp
participant NativeRatioKLOp
participant Advantage as Advantage Normalization
participant Loss as Loss Computation
User->>NativeGRPOLossOp: forward(policy_logits, ref_logits, action_ids, old_logps, rewards)
NativeGRPOLossOp->>Advantage: group_advantages(rewards, boundaries)
Advantage-->>NativeGRPOLossOp: sample_advantages
NativeGRPOLossOp->>NativeRatioKLOp: __call__(policy_logits, ref_logits, action_ids, ...)
NativeRatioKLOp->>NativeRatioKLOp: select log-probs, compute delta/diff
NativeRatioKLOp-->>NativeGRPOLossOp: ratio, kl_terms
NativeGRPOLossOp->>Loss: expand_advantages, compute clipped surrogate
Loss-->>NativeGRPOLossOp: loss, policy_loss, kl
NativeGRPOLossOp-->>User: (loss, policy_loss, kl), backward graph
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rl_engine/kernels/ops/pytorch/loss/grpo_loss.py (1)
192-203:⚠️ Potential issue | 🟠 Major | ⚡ Quick winMirror Triton’s
group_boundariesvalidation here.
_resolve_group_ids()currently accepts translated boundary vectors because it only checkssizes. For example, a 4-sequence batch would accept[2, 4, 6]and silently treat it like[0, 2, 4], whileTritonGRPOLossOp._build_bounds()rejects the same input because it requires batch-local offsets that start at0and end atnum_sequences. That makes the shared API behave differently across backends depending on dispatch target.Suggested fix
boundaries = torch.as_tensor(group_boundaries, device=device, dtype=torch.long) if boundaries.ndim != 1 or boundaries.numel() < 2: raise ValueError("group_boundaries must be a 1D tensor of length num_groups + 1.") + if int(boundaries[0].item()) != 0 or int(boundaries[-1].item()) != num_sequences: + raise ValueError("group_boundaries must start at 0 and end at num_sequences.") sizes = boundaries[1:] - boundaries[:-1] - - if int(sizes.sum().item()) != num_sequences: - raise ValueError( - f"group sizes sum to {int(sizes.sum().item())} but there are " - f"{num_sequences} sequences." - ) if bool((sizes < 1).any().item()): raise ValueError("each group must contain at least one sequence.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py` around lines 192 - 203, _update _resolve_group_ids() to mirror TritonGRPOLossOp._build_bounds() validation: ensure group_boundaries is a 1D integer tensor of length num_groups+1 on the correct device, that boundaries[0] == 0 and boundaries[-1] == num_sequences, and that the sequence is non-decreasing (each boundary >= previous) and all group sizes are >=1; raise a ValueError with a clear message if any of these checks fail so the PyTorch path rejects translated/offset vectors like [2,4,6] just like Triton does.
🧹 Nitpick comments (2)
docs/operators/ratio-kl.md (1)
9-11: 💤 Low valueConsider adding a language identifier to the fenced code block.
The ASCII pipeline diagram lacks a language specifier. Adding
textimproves rendering and addresses the linter warning.📝 Proposed fix
-``` +```text logits --[ratio_kl op]--> (ratio, kl) --[clipped surrogate + beta*kl]--> loss</details> <details> <summary>🤖 Prompt for AI Agents</summary>Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.In
@docs/operators/ratio-kl.mdaround lines 9 - 11, The fenced ASCII pipeline
block containing "logits --[ratio_kl op]--> (ratio, kl) --[clipped surrogate +
beta*kl]--> loss" should include a language identifier (e.g., text) to satisfy
the linter and improve rendering; update the triple backtick fence that
surrounds the pipeline diagram to use ```text so the block becomes a text code
block while keeping the pipeline content unchanged.</details> <!-- cr-comment:v1:ade0dc832571fe86b6ad1795 --> _Source: Linters/SAST tools_ </blockquote></details> <details> <summary>docs/operators/grpo-loss.md (1)</summary><blockquote> `15-15`: _💤 Low value_ **Consider adding a language identifier to the fenced code block.** The ASCII pipeline diagram uses a fenced code block without a language specifier. While this is valid Markdown, adding `text` improves rendering consistency and silences the linter warning. <details> <summary>📝 Proposed fix</summary> ```diff -``` +```text logits --[ratio_kl op]--> (ratio, kl) --[group adv + clipped surrogate]--> loss ``` ``` </details> <details> <summary>🤖 Prompt for AI Agents</summary>Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.In
@docs/operators/grpo-loss.mdat line 15, The fenced ASCII pipeline diagram
lacks a language tag; update the fenced code block that contains "logits
--[ratio_kl op]--> (ratio, kl) --[group adv + clipped surrogate]--> loss" to
include a language identifier (use "text") after the opening backticks so the
block becomestext ...to satisfy the linter and improve rendering
consistency.</details> <!-- cr-comment:v1:58983bcca01d03c1472ca2d1 --> _Source: Linters/SAST tools_ </blockquote></details> </blockquote></details> <details> <summary>🤖 Prompt for all review comments with AI agents</summary>Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.Inline comments:
In@benchmarks/benchmark_grpo_loss.py:
- Around line 93-109: The nested functions native_fwd, triton_fwd,
native_fwd_bwd, and triton_fwd_bwd capture loop-scoped variables call_args and
spp causing potential late-binding (Ruff B023); fix by binding those loop
variables as default arguments in each function signature (e.g.,
native_fwd(p=policy, call_args=call_args, spp=spp) and similarly triton_fwd, and
for the backward helpers use p=pol_grad, call_args=call_args, spp=spp) so each
closure retains the current iteration's values instead of the final ones.In
@docs/operators/grpo-loss.md:
- Line 30: Doc parameter names are inconsistent: grpo_loss docs use
completion_mask while ratio_kl docs and implementations use attention_mask;
update the docs so the same tensor name is used and unambiguous. Pick one
canonical name (e.g., attention_mask) and rename the parameter in
docs/operators/grpo-loss.md (and any related doc examples) from completion_mask
to attention_mask, or add a single clarifying sentence in grpo_loss docs stating
that completion_mask is the same tensor passed as attention_mask to ratio_kl;
ensure references to shapes ([B, T]) and callsites in grpo_loss, ratio_kl, and
fused _ratio_kl are updated to match the chosen terminology.In
@docs/operators/ratio-kl.md:
- Around line 43-48: Add a short introductory sentence before the formula to
explain what it represents (for example: "The Triton backward kernel
computes:"), add a language identifier to the fenced code block (e.g.,text orpython) around the linegrad_policy_logits[v] = c * (1[v == action] - softmax_policy(v)), and either move or expand the following sentence ("so the
backward also avoids materializing...") so it immediately relates to the formula
(e.g., state that this expression means no [B, T, V] probability tensor is
materialized), ensuring the formula is connected to the preceding discussion of
backends and wrapper classes.In
@rl_engine/kernels/ops/triton/loss/ratio_kl.py:
- Around line 38-67: The kernel reads an unclamped action id "a" (loaded as a =
tl.load(action_ptr + row)) and then uses it to index policy_ptr/ref_ptr, which
can cause out-of-bounds reads for active tokens; clamp "a" to the valid range
[0, V-1] before any indexing (e.g., replace uses of "a" with a_safe =
tl.minimum(tl.maximum(a, 0), V - 1) or tl.clamp if available) and ensure a_safe
has the correct integer dtype matching row_off offsets; update the loads
tl.load(policy_ptr + row_off + a) and tl.load(ref_ptr + row_off + a) to use
a_safe so active-path indexing is always safe in the ratio_kl Triton forward
kernel.
Outside diff comments:
In@rl_engine/kernels/ops/pytorch/loss/grpo_loss.py:
- Around line 192-203: _update _resolve_group_ids() to mirror
TritonGRPOLossOp._build_bounds() validation: ensure group_boundaries is a 1D
integer tensor of length num_groups+1 on the correct device, that boundaries[0]
== 0 and boundaries[-1] == num_sequences, and that the sequence is
non-decreasing (each boundary >= previous) and all group sizes are >=1; raise a
ValueError with a clear message if any of these checks fail so the PyTorch path
rejects translated/offset vectors like [2,4,6] just like Triton does.
Nitpick comments:
In@docs/operators/grpo-loss.md:
- Line 15: The fenced ASCII pipeline diagram lacks a language tag; update the
fenced code block that contains "logits --[ratio_kl op]--> (ratio, kl) --[group
adv + clipped surrogate]--> loss" to include a language identifier (use "text")
after the opening backticks so the block becomestext ...to satisfy the
linter and improve rendering consistency.In
@docs/operators/ratio-kl.md:
- Around line 9-11: The fenced ASCII pipeline block containing "logits
--[ratio_kl op]--> (ratio, kl) --[clipped surrogate + beta*kl]--> loss" should
include a language identifier (e.g., text) to satisfy the linter and improve
rendering; update the triple backtick fence that surrounds the pipeline diagram
to use ```text so the block becomes a text code block while keeping the pipeline
content unchanged.</details> <details> <summary>🪄 Autofix (Beta)</summary> Fix all unresolved CodeRabbit comments on this PR: - [ ] <!-- {"checkboxId": "4b0d0e0a-96d7-4f10-b296-3a18ea78f0b9"} --> Push a commit to this branch (recommended) - [ ] <!-- {"checkboxId": "ff5b1114-7d8c-49e6-8ac1-43f82af23a33"} --> Create a new PR with the fixes </details> --- <details> <summary>ℹ️ Review info</summary> <details> <summary>⚙️ Run configuration</summary> **Configuration used**: defaults **Review profile**: CHILL **Plan**: Pro **Run ID**: `b91d50a6-caba-4711-b397-1074cc35634b` </details> <details> <summary>📥 Commits</summary> Reviewing files that changed from the base of the PR and between 4a9ca42a07c9da6b5634375e1043f9150b19a9e7 and 6a858ae077e3302f6c0456911a289f4cb11481e5. </details> <details> <summary>📒 Files selected for processing (14)</summary> * `benchmarks/benchmark_grpo_loss.py` * `benchmarks/benchmark_ratio_kl.py` * `docs/.nav.yml` * `docs/operators/README.md` * `docs/operators/grpo-loss.md` * `docs/operators/ratio-kl.md` * `rl_engine/kernels/ops/pytorch/loss/grpo_loss.py` * `rl_engine/kernels/ops/pytorch/loss/ratio_kl.py` * `rl_engine/kernels/ops/triton/loss/__init__.py` * `rl_engine/kernels/ops/triton/loss/grpo_loss.py` * `rl_engine/kernels/ops/triton/loss/ratio_kl.py` * `rl_engine/kernels/registry.py` * `tests/test_grpo_loss.py` * `tests/test_ratio_kl.py` </details> </details> <!-- This is an auto-generated comment by CodeRabbit for review status -->
| def native_fwd(p=policy): | ||
| with torch.no_grad(): | ||
| native.forward(c, *call_args, samples_per_prompt=spp, **kwargs) | ||
| native.forward(p, *call_args, samples_per_prompt=spp, **kwargs) | ||
|
|
||
| def triton_fwd(c=current): | ||
| def triton_fwd(p=policy): | ||
| with torch.no_grad(): | ||
| triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs) | ||
| triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs) | ||
|
|
||
| cur_grad = current.clone().requires_grad_(True) | ||
| pol_grad = policy.clone().requires_grad_(True) | ||
|
|
||
| def native_fwd_bwd(c=cur_grad): | ||
| loss, _, _ = native.forward(c, *call_args, samples_per_prompt=spp, **kwargs) | ||
| torch.autograd.grad(loss, c) | ||
| def native_fwd_bwd(p=pol_grad): | ||
| loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs) | ||
| torch.autograd.grad(loss, p) | ||
|
|
||
| def triton_fwd_bwd(c=cur_grad): | ||
| loss, _, _ = triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs) | ||
| torch.autograd.grad(loss, c) | ||
| def triton_fwd_bwd(p=pol_grad): | ||
| loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs) | ||
| torch.autograd.grad(loss, p) |
There was a problem hiding this comment.
Bind loop variables as default arguments to avoid late binding.
The nested functions capture call_args and spp from the loop scope without binding them as default arguments. While the functions are called immediately within the same iteration (making this safe in practice), it's still a code smell flagged by Ruff (B023). If the benchmark logic changes and these functions are stored or called later, they would reference the last loop iteration's values.
🛡️ Recommended fix
- def native_fwd(p=policy):
+ def native_fwd(p=policy, args=call_args, s=spp):
with torch.no_grad():
- native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+ native.forward(p, *args, samples_per_prompt=s, **kwargs)
- def triton_fwd(p=policy):
+ def triton_fwd(p=policy, args=call_args, s=spp):
with torch.no_grad():
- triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+ triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)
pol_grad = policy.clone().requires_grad_(True)
- def native_fwd_bwd(p=pol_grad):
- loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+ def native_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+ loss, _, _ = native.forward(p, *args, samples_per_prompt=s, **kwargs)
torch.autograd.grad(loss, p)
- def triton_fwd_bwd(p=pol_grad):
- loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+ def triton_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+ loss, _, _ = triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)
torch.autograd.grad(loss, p)🧰 Tools
🪛 Ruff (0.15.15)
[warning] 95-95: Function definition does not bind loop variable call_args
(B023)
[warning] 95-95: Function definition does not bind loop variable spp
(B023)
[warning] 99-99: Function definition does not bind loop variable call_args
(B023)
[warning] 99-99: Function definition does not bind loop variable spp
(B023)
[warning] 104-104: Function definition does not bind loop variable call_args
(B023)
[warning] 104-104: Function definition does not bind loop variable spp
(B023)
[warning] 108-108: Function definition does not bind loop variable call_args
(B023)
[warning] 108-108: Function definition does not bind loop variable spp
(B023)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmarks/benchmark_grpo_loss.py` around lines 93 - 109, The nested
functions native_fwd, triton_fwd, native_fwd_bwd, and triton_fwd_bwd capture
loop-scoped variables call_args and spp causing potential late-binding (Ruff
B023); fix by binding those loop variables as default arguments in each function
signature (e.g., native_fwd(p=policy, call_args=call_args, spp=spp) and
similarly triton_fwd, and for the backward helpers use p=pol_grad,
call_args=call_args, spp=spp) so each closure retains the current iteration's
values instead of the final ones.
| action_ids, # [B, T] token taken at each position | ||
| old_logps, # [B, T] cached behavior-policy log-probs | ||
| rewards, # [B] | ||
| completion_mask, # [B, T] |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check mask parameter naming in grpo_loss and ratio_kl implementations
# Check grpo_loss operator signatures
echo "=== grpo_loss signatures ==="
ast-grep --pattern $'def forward($$$, completion_mask, $$$)'
ast-grep --pattern $'def forward($$$, attention_mask, $$$)'
# Check ratio_kl operator signatures
echo "=== ratio_kl signatures ==="
rg -A 3 'def forward' rl_engine/kernels/ops/pytorch/loss/ratio_kl.py
rg -A 3 'def forward' rl_engine/kernels/ops/triton/loss/ratio_kl.pyRepository: RL-Align/RL-Kernel
Length of output: 3188
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Find mask parameter naming in docs =="
sed -n '1,120p' docs/operators/grpo-loss.md | nl -ba | sed -n '20,60p'
echo
sed -n '1,120p' docs/operators/ratio-kl.md | nl -ba | sed -n '1,80p'
echo
echo "== Search code for completion_mask and grpo_loss =="
rg -n --hidden --glob '!**/node_modules/**' --glob '!**/.git/**' "completion_mask" .
rg -n --hidden --glob '!**/node_modules/**' --glob '!**/.git/**' "grpo_loss" rl_engine docs .
echo
echo "== Search ratio_kl signatures for attention_mask/completion_mask =="
rg -n "def forward\(" rl_engine/kernels/ops -S --glob '*ratio_kl*'
rg -n "attention_mask" rl_engine/kernels/ops -S --glob '*ratio_kl*'
rg -n "completion_mask" rl_engine/kernels/ops -S --glob '*ratio_kl*'Repository: RL-Align/RL-Kernel
Length of output: 164
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Docs: grpo-loss mask naming =="
cat -n docs/operators/grpo-loss.md | sed -n '1,120p'
echo
echo "== Docs: ratio-kl mask naming =="
cat -n docs/operators/ratio-kl.md | sed -n '1,120p'
echo
echo "== Code search: completion_mask and grpo_loss =="
rg -n "completion_mask" rl_engine docs || true
rg -n "grpo_loss" rl_engine docs || true
echo
echo "== Code search: ratio_kl call sites =="
rg -n "ratio_kl" rl_engine/kernels/ops -S || true
echo
echo "== Code search: attention_mask/completion_mask in ratio_kl =="
rg -n "attention_mask" rl_engine/kernels/ops -S --glob '*ratio*kl*' || true
rg -n "completion_mask" rl_engine/kernels/ops -S --glob '*ratio*kl*' || true
echo
echo "== Find grpo_loss operator implementation files =="
fd -e py -e md -e ts -e js grpo -a rl_engine docs || true
fd -e py -e md grpo_loss -a rl_engine docs || trueRepository: RL-Align/RL-Kernel
Length of output: 19926
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== grpo_loss: pytorch call to ratio_kl =="
cat -n rl_engine/kernels/ops/pytorch/loss/grpo_loss.py | sed -n '90,140p'
echo
echo "== grpo_loss: triton call to ratio_kl =="
cat -n rl_engine/kernels/ops/triton/loss/grpo_loss.py | sed -n '180,240p'
echo
echo "== ratio_kl: pytorch forward signature =="
cat -n rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | sed -n '1,80p'
echo
echo "== ratio_kl: triton forward signature (ctx) =="
cat -n rl_engine/kernels/ops/triton/loss/ratio_kl.py | sed -n '110,220p'Repository: RL-Align/RL-Kernel
Length of output: 11128
Align mask parameter naming across grpo_loss and ratio_kl docs/contracts.
docs/operators/grpo-loss.mddocuments the mask argument ascompletion_mask([B, T]), andgrpo_lossimplementations takecompletion_maskand pass it positionally into the fused_ratio_klcall.docs/operators/ratio-kl.md(and theratio_klimplementations) name that same mask argument asattention_mask([B, T]active tokens).
Rename one of the doc parameters for consistency (or explicitly document that completion_mask in grpo_loss is the same tensor as attention_mask consumed by ratio_kl), to avoid confusion when users call these ops directly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/operators/grpo-loss.md` at line 30, Doc parameter names are
inconsistent: grpo_loss docs use completion_mask while ratio_kl docs and
implementations use attention_mask; update the docs so the same tensor name is
used and unambiguous. Pick one canonical name (e.g., attention_mask) and rename
the parameter in docs/operators/grpo-loss.md (and any related doc examples) from
completion_mask to attention_mask, or add a single clarifying sentence in
grpo_loss docs stating that completion_mask is the same tensor passed as
attention_mask to ratio_kl; ensure references to shapes ([B, T]) and callsites
in grpo_loss, ratio_kl, and fused _ratio_kl are updated to match the chosen
terminology.
| ``` | ||
| grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v)) | ||
| ``` | ||
|
|
||
| so the backward also avoids materializing any `[B, T, V]` probability tensor (only the | ||
| unavoidable `[B, T, V]` gradient output is written). |
There was a problem hiding this comment.
Incomplete or unclear context for the gradient formula.
The fenced code block at lines 43-45 shows a gradient formula but appears disconnected from surrounding text. The preceding paragraph (lines 36-42) discusses backends and wrapper classes, then the formula appears without introduction. Additionally, the formula block lacks a language identifier.
📋 Suggested improvements
- Add introductory text before the formula explaining what it represents (e.g., "The Triton backward kernel computes:").
- Add a language identifier to the code block (
python ortext). - Consider whether the text at lines 47-48 ("so the backward also avoids...") should be moved before the formula or expanded to provide better context.
Example fix:
| PyTorch fallback | `NativeRatioKLOp` | None | Reference path; CPU and Triton-less GPUs. |
+The Triton backward kernel computes the policy logits gradient analytically:
+
-```
+```python
grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v))</details>
<details>
<summary>🧰 Tools</summary>
<details>
<summary>🪛 markdownlint-cli2 (0.22.1)</summary>
[warning] 43-43: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
</details>
</details>
<details>
<summary>🤖 Prompt for AI Agents</summary>
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In @docs/operators/ratio-kl.md around lines 43 - 48, Add a short introductory
sentence before the formula to explain what it represents (for example: "The
Triton backward kernel computes:"), add a language identifier to the fenced code
block (e.g., text or python) around the line grad_policy_logits[v] = c * (1[v == action] - softmax_policy(v)), and either move or expand the following
sentence ("so the backward also avoids materializing...") so it immediately
relates to the formula (e.g., state that this expression means no [B, T, V]
probability tensor is materialized), ensuring the formula is connected to the
preceding discussion of backends and wrapper classes.
</details>
<!-- fingerprinting:phantom:triton:puma -->
<!-- cr-comment:v1:bfe80326597ee1dea81e01e0 -->
_Source: Linters/SAST tools_
<!-- This is an auto-generated comment by CodeRabbit -->
| if active: | ||
| row_off = row.to(tl.int64) * V | ||
| a = tl.load(action_ptr + row) | ||
|
|
||
| max_p = -float("inf") | ||
| sum_p = 0.0 | ||
| max_r = -float("inf") | ||
| sum_r = 0.0 | ||
| for start in range(0, V, BLOCK_V): | ||
| cols = start + tl.arange(0, BLOCK_V) | ||
| cmask = cols < V | ||
| p = tl.load(policy_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | ||
| r = tl.load(ref_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | ||
|
|
||
| tmax_p = tl.max(p, axis=0) | ||
| nmax_p = tl.maximum(max_p, tmax_p) | ||
| sum_p = sum_p * tl.exp(max_p - nmax_p) + tl.sum(tl.exp(p - nmax_p), axis=0) | ||
| max_p = nmax_p | ||
|
|
||
| tmax_r = tl.max(r, axis=0) | ||
| nmax_r = tl.maximum(max_r, tmax_r) | ||
| sum_r = sum_r * tl.exp(max_r - nmax_r) + tl.sum(tl.exp(r - nmax_r), axis=0) | ||
| max_r = nmax_r | ||
|
|
||
| logz_p = max_p + tl.log(sum_p) | ||
| logz_r = max_r + tl.log(sum_r) | ||
| pa = tl.load(policy_ptr + row_off + a).to(tl.float32) | ||
| ra = tl.load(ref_ptr + row_off + a).to(tl.float32) | ||
| logp_p = pa - logz_p | ||
| logp_r = ra - logz_r |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if action_ids can be out-of-bounds at call sites
# Search for ratio_kl operator calls and examine action_ids sources
rg -nP -C5 '(TritonRatioKLOp|NativeRatioKLOp|ratio_kl.*\(|\.get_op\(["\']ratio_kl)' \
--type=py \
-g '!test*.py'
# Look for any validation or clamping of action_ids/token_ids before operator calls
rg -nP -C3 '(action_ids|token_ids).*(clamp|clip|clip_|where|masked_fill)' \
--type=py \
-g '!test*.py'Repository: RL-Align/RL-Kernel
Length of output: 235
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect Triton and PyTorch implementations around the referenced lines.
echo "== Triton file (ratio_kl.py) =="
sed -n '1,220p' rl_engine/kernels/ops/triton/loss/ratio_kl.py | nl -ba | sed -n '1,220p'
echo "== PyTorch file (ratio_kl.py) =="
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | nl -ba | sed -n '1,220p'
# Find call sites/usages of ratio_kl / ratio kl ops.
echo "== Search for ratio_kl call sites =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
-S '(ratio_kl\.|ratio_kl\b|TritonRatioKLOp|NativeRatioKLOp)' \
-g '*.py'
# Search for any clamping/validation of action_ids / token_ids before operator invocation.
echo "== Search for clamps on action_ids/token_ids =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
-S 'action_ids|token_ids' \
-g '*.py' | head -n 200
echo "== Search for clamp/clip usage near action_ids/token_ids =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
-S '(action_ids|token_ids).{0,80}(clamp|clip|clip_|\bwhere\b|masked_fill)' \
-g '*.py'Repository: RL-Align/RL-Kernel
Length of output: 155
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Triton file (ratio_kl.py) =="
cat -n rl_engine/kernels/ops/triton/loss/ratio_kl.py | sed -n '1,220p'
echo "== PyTorch file (ratio_kl.py) =="
cat -n rl_engine/kernels/ops/pytorch/loss/ratio_kl.py | sed -n '1,220p'
echo "== Search for ratio_kl call sites =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
'ratio_kl' \
rl_engine -S
echo "== Search for clamps/validation of action_ids/token_ids =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
'(action_ids|token_ids).*\.?(clamp|clip|where|masked_fill)\b|safe_ids|safe_.*ids' \
rl_engine -S
echo "== Search for range checks involving action_ids/token_ids =="
rg -n --hidden --no-ignore-vcs --glob '!test*.py' \
'(action_ids|token_ids).*(<|<=|>|>=).*(V|vocab|size|logits)' \
rl_engine -S || trueRepository: RL-Align/RL-Kernel
Length of output: 12176
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Triton GRPO loss: around _ratio_kl invocation =="
rg -n "ratio, kl_terms = self\\._ratio_kl|self\\._ratio_kl\\(" \
rl_engine/kernels/ops/triton/loss/grpo_loss.py
sed -n '160,260p' rl_engine/kernels/ops/triton/loss/grpo_loss.py
echo "== Triton GRPO loss: earlier preparation of action_ids / attention_mask =="
rg -n "action_ids|attention_mask" rl_engine/kernels/ops/triton/loss/grpo_loss.py
sed -n '1,160p' rl_engine/kernels/ops/triton/loss/grpo_loss.py
echo "== PyTorch GRPO loss: around _ratio_kl invocation =="
rg -n "ratio, kl_terms = self\\._ratio_kl|self\\._ratio_kl\\(" \
rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
sed -n '80,160p' rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
echo "== PyTorch GRPO loss: earlier preparation of action_ids / attention_mask =="
rg -n "action_ids|attention_mask" rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
sed -n '1,90p' rl_engine/kernels/ops/pytorch/loss/grpo_loss.py
echo "== reference_ops: gather_token_ids masked_fill context =="
sed -n '1,120p' rl_engine/testing/reference_ops.py
rg -n "gather_token_ids" rl_engine/testing/reference_ops.pyRepository: RL-Align/RL-Kernel
Length of output: 20105
Critical: Triton ratio_kl uses unclamped action_ids for active tokens
rl_engine/kernels/ops/triton/loss/ratio_kl.pyloadsa = tl.load(action_ptr + row)and indexes logits withtl.load(policy_ptr + row_off + a)/tl.load(ref_ptr + row_off + a)without clamping/validation.- Although the kernel only reads
awhenactive = attention_mask_flat[row] != 0, any out-of-rangeaction_idsat active positions (a < 0ora >= V) can produce out-of-bounds reads and incorrect losses/grads. - PyTorch fallback
rl_engine/kernels/ops/pytorch/loss/ratio_kl.pyclampsaction_idsviasafe_ids = action_ids.clamp(0, logits.size(-1) - 1).long(), so behavior is inconsistent when upstream supplies invalid active ids. - No other range checks/clamping for
action_idswere found in the repo call paths (only this PyTorch clamp).
🛡️ Recommended fix: Clamp `a` in the Triton forward kernel
if active:
row_off = row.to(tl.int64) * V
a = tl.load(action_ptr + row)
+ # Clamp action_id to valid vocab range [0, V)
+ a = tl.maximum(0, tl.minimum(a, V - 1))
max_p = -float("inf")
sum_p = 0.0📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if active: | |
| row_off = row.to(tl.int64) * V | |
| a = tl.load(action_ptr + row) | |
| max_p = -float("inf") | |
| sum_p = 0.0 | |
| max_r = -float("inf") | |
| sum_r = 0.0 | |
| for start in range(0, V, BLOCK_V): | |
| cols = start + tl.arange(0, BLOCK_V) | |
| cmask = cols < V | |
| p = tl.load(policy_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | |
| r = tl.load(ref_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | |
| tmax_p = tl.max(p, axis=0) | |
| nmax_p = tl.maximum(max_p, tmax_p) | |
| sum_p = sum_p * tl.exp(max_p - nmax_p) + tl.sum(tl.exp(p - nmax_p), axis=0) | |
| max_p = nmax_p | |
| tmax_r = tl.max(r, axis=0) | |
| nmax_r = tl.maximum(max_r, tmax_r) | |
| sum_r = sum_r * tl.exp(max_r - nmax_r) + tl.sum(tl.exp(r - nmax_r), axis=0) | |
| max_r = nmax_r | |
| logz_p = max_p + tl.log(sum_p) | |
| logz_r = max_r + tl.log(sum_r) | |
| pa = tl.load(policy_ptr + row_off + a).to(tl.float32) | |
| ra = tl.load(ref_ptr + row_off + a).to(tl.float32) | |
| logp_p = pa - logz_p | |
| logp_r = ra - logz_r | |
| if active: | |
| row_off = row.to(tl.int64) * V | |
| a = tl.load(action_ptr + row) | |
| # Clamp action_id to valid vocab range [0, V) | |
| a = tl.maximum(0, tl.minimum(a, V - 1)) | |
| max_p = -float("inf") | |
| sum_p = 0.0 | |
| max_r = -float("inf") | |
| sum_r = 0.0 | |
| for start in range(0, V, BLOCK_V): | |
| cols = start + tl.arange(0, BLOCK_V) | |
| cmask = cols < V | |
| p = tl.load(policy_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | |
| r = tl.load(ref_ptr + row_off + cols, mask=cmask, other=-float("inf")).to(tl.float32) | |
| tmax_p = tl.max(p, axis=0) | |
| nmax_p = tl.maximum(max_p, tmax_p) | |
| sum_p = sum_p * tl.exp(max_p - nmax_p) + tl.sum(tl.exp(p - nmax_p), axis=0) | |
| max_p = nmax_p | |
| tmax_r = tl.max(r, axis=0) | |
| nmax_r = tl.maximum(max_r, tmax_r) | |
| sum_r = sum_r * tl.exp(max_r - nmax_r) + tl.sum(tl.exp(r - nmax_r), axis=0) | |
| max_r = nmax_r | |
| logz_p = max_p + tl.log(sum_p) | |
| logz_r = max_r + tl.log(sum_r) | |
| pa = tl.load(policy_ptr + row_off + a).to(tl.float32) | |
| ra = tl.load(ref_ptr + row_off + a).to(tl.float32) | |
| logp_p = pa - logz_p | |
| logp_r = ra - logz_r |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/loss/ratio_kl.py` around lines 38 - 67, The
kernel reads an unclamped action id "a" (loaded as a = tl.load(action_ptr +
row)) and then uses it to index policy_ptr/ref_ptr, which can cause
out-of-bounds reads for active tokens; clamp "a" to the valid range [0, V-1]
before any indexing (e.g., replace uses of "a" with a_safe =
tl.minimum(tl.maximum(a, 0), V - 1) or tl.clamp if available) and ensure a_safe
has the correct integer dtype matching row_off offsets; update the loads
tl.load(policy_ptr + row_off + a) and tl.load(ref_ptr + row_off + a) to use
a_safe so active-path indexing is always safe in the ratio_kl Triton forward
kernel.
#40
Summary
What changed and why
New operator —
ratio_klrl_engine/kernels/ops/triton/loss/ratio_kl.py:_RatioKLFunctionautograd +TritonRatioKLOp.BLOCK_Vtiles for policy and refrl_engine/kernels/ops/pytorch/loss/ratio_kl.py:NativeRatioKLOpreference, reusingNativeLogpOptests/test_ratio_kl.py: native-vs-reference, masked neutrality, ratio==1 whenold==policy, grad-to-policy-only, Triton fwd/bwd vs native (V=64 and 50257), registry dispatch.benchmarks/benchmark_ratio_kl.py: New benchmark (latency + peak VRAM, vocab sweep).docs/operators/ratio-kl.md: operator doc.GRPO loss refactor — compose
ratio_klinstead of a bespoke kernelrl_engine/kernels/ops/triton/{triton_grpo_loss.py → loss/grpo_loss.py}:loss/;_grpo_fwd/_grpo_bwdkernels and_GRPOLossFunction;TritonRatioKLOpand keeps only_group_norm_kernelfor reward normalization.benchmarks/benchmark_grpo_loss.py: Updated to logits APItests/test_grpo_loss.py: Loss/grad/masking/SGD tests reworked for logits APIdocs/operators/grpo-loss.mdWiring
rl_engine/kernels/registry.py: Registerratio_kl(Triton/PyTorch) for cuda/rocm/cpu; repointTRITON_GRPO_LOSStotriton.loss.grpo_lossPerformance (fp16, B=32, T=256)
Notes
Summary by CodeRabbit
Release Notes
New Features
ratio_kloperator for fused policy ratio and KL penalty computation in PPO/GRPO workflows.Documentation
ratio_kloperator with tensor contract and performance specifications.Tests
ratio_kloperator covering correctness, masking, and gradient flow.Benchmarking