Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions benchmarks/benchmark_grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

"""Benchmark NativeGRPOLossOp vs TritonGRPOLossOp.

Reports forward and forward+backward latency (and peak extra VRAM for the
forward pass) across a range of (groups x samples x completion-length) shapes.
The ops operate on per-token log-probs, so the working set scales with
N = num_prompts * samples_per_prompt * completion_len.
Both ops consume logits and build on the fused ratio/KL op, so the working set
scales with the logits tensor [B, T, V] (B = num_prompts * samples_per_prompt).
Reports forward and forward+backward latency plus peak extra VRAM for the
forward pass across a range of (groups x samples x completion-length x vocab)
shapes; the vocab axis is where the online-softmax fusion pays off.

Usage:
python benchmarks/benchmark_grpo_loss.py
python benchmarks/benchmark_grpo_loss.py --iters 50 --clip-eps 0.2 --beta 0.04
python benchmarks/benchmark_grpo_loss.py --configs "4,8,256,32768;4,8,256,131072"
"""

import argparse
Expand All @@ -19,27 +20,27 @@
from tabulate import tabulate

from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp
from rl_engine.kernels.ops.triton.triton_grpo_loss import TritonGRPOLossOp
from rl_engine.kernels.ops.triton.loss.grpo_loss import TritonGRPOLossOp
from rl_engine.platforms.device import device_ctx
from rl_engine.utils.logger import logger

# (num_prompts, samples_per_prompt, completion_len)
# (num_prompts, samples_per_prompt, completion_len, vocab)
DEFAULT_CONFIGS = [
(32, 8, 256),
(64, 8, 512),
(128, 8, 1024),
(256, 16, 1024),
(4, 8, 256, 32768),
(4, 8, 256, 50257),
(4, 8, 256, 131072),
]


def _make_inputs(num_prompts, spp, completion_len, device, dtype):
def _make_inputs(num_prompts, spp, completion_len, vocab, device, dtype):
batch = num_prompts * spp
current = torch.randn(batch, completion_len, device=device, dtype=dtype)
policy = torch.randn(batch, completion_len, vocab, device=device, dtype=dtype)
ref = torch.randn(batch, completion_len, vocab, device=device, dtype=dtype)
action_ids = torch.randint(0, vocab, (batch, completion_len), device=device)
old = torch.randn(batch, completion_len, device=device, dtype=dtype)
ref = torch.randn(batch, completion_len, device=device, dtype=dtype)
rewards = torch.randn(batch, device=device, dtype=dtype)
mask = torch.ones(batch, completion_len, dtype=torch.bool, device=device)
return current, old, ref, rewards, mask
return policy, ref, action_ids, old, rewards, mask


def _time_ms(fn, warmup, iters):
Expand Down Expand Up @@ -74,36 +75,38 @@ def run_benchmark(args):
raise RuntimeError("GRPO loss benchmark requires a CUDA device (Triton op is CUDA-only).")

device = device_ctx.device
dtype = torch.float32
dtype = torch.float16
native = NativeGRPOLossOp()
triton_op = TritonGRPOLossOp()
kwargs = dict(clip_eps=args.clip_eps, beta=args.beta)

logger.info(f"GRPO loss benchmark on {device} (dtype={dtype})")

rows = []
for num_prompts, spp, comp_len in args.configs:
current, old, ref, rewards, mask = _make_inputs(num_prompts, spp, comp_len, device, dtype)
n_tokens = current.numel()
call_args = (old, ref, rewards, mask)
for num_prompts, spp, comp_len, vocab in args.configs:
policy, ref, action_ids, old, rewards, mask = _make_inputs(
num_prompts, spp, comp_len, vocab, device, dtype
)
n_tokens = num_prompts * spp * comp_len
call_args = (ref, action_ids, old, rewards, mask)

def native_fwd(c=current):
def native_fwd(p=policy):
with torch.no_grad():
native.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)

def triton_fwd(c=current):
def triton_fwd(p=policy):
with torch.no_grad():
triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)

cur_grad = current.clone().requires_grad_(True)
pol_grad = policy.clone().requires_grad_(True)

def native_fwd_bwd(c=cur_grad):
loss, _, _ = native.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
torch.autograd.grad(loss, c)
def native_fwd_bwd(p=pol_grad):
loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
torch.autograd.grad(loss, p)

def triton_fwd_bwd(c=cur_grad):
loss, _, _ = triton_op.forward(c, *call_args, samples_per_prompt=spp, **kwargs)
torch.autograd.grad(loss, c)
def triton_fwd_bwd(p=pol_grad):
loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
torch.autograd.grad(loss, p)
Comment on lines +93 to +109

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Bind loop variables as default arguments to avoid late binding.

The nested functions capture call_args and spp from the loop scope without binding them as default arguments. While the functions are called immediately within the same iteration (making this safe in practice), it's still a code smell flagged by Ruff (B023). If the benchmark logic changes and these functions are stored or called later, they would reference the last loop iteration's values.

🛡️ Recommended fix
-        def native_fwd(p=policy):
+        def native_fwd(p=policy, args=call_args, s=spp):
             with torch.no_grad():
-                native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+                native.forward(p, *args, samples_per_prompt=s, **kwargs)

-        def triton_fwd(p=policy):
+        def triton_fwd(p=policy, args=call_args, s=spp):
             with torch.no_grad():
-                triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+                triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)

         pol_grad = policy.clone().requires_grad_(True)

-        def native_fwd_bwd(p=pol_grad):
-            loss, _, _ = native.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+        def native_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+            loss, _, _ = native.forward(p, *args, samples_per_prompt=s, **kwargs)
             torch.autograd.grad(loss, p)

-        def triton_fwd_bwd(p=pol_grad):
-            loss, _, _ = triton_op.forward(p, *call_args, samples_per_prompt=spp, **kwargs)
+        def triton_fwd_bwd(p=pol_grad, args=call_args, s=spp):
+            loss, _, _ = triton_op.forward(p, *args, samples_per_prompt=s, **kwargs)
             torch.autograd.grad(loss, p)
🧰 Tools
🪛 Ruff (0.15.15)

[warning] 95-95: Function definition does not bind loop variable call_args

(B023)


[warning] 95-95: Function definition does not bind loop variable spp

(B023)


[warning] 99-99: Function definition does not bind loop variable call_args

(B023)


[warning] 99-99: Function definition does not bind loop variable spp

(B023)


[warning] 104-104: Function definition does not bind loop variable call_args

(B023)


[warning] 104-104: Function definition does not bind loop variable spp

(B023)


[warning] 108-108: Function definition does not bind loop variable call_args

(B023)


[warning] 108-108: Function definition does not bind loop variable spp

(B023)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_grpo_loss.py` around lines 93 - 109, the nested
functions native_fwd, triton_fwd, native_fwd_bwd, and triton_fwd_bwd read the
loop-scoped variables call_args and spp from the enclosing scope instead of
binding them, which Ruff flags as potential late binding (B023). Fix this by
binding those loop variables as default arguments in each function signature
(e.g., native_fwd(p=policy, call_args=call_args, spp=spp), and likewise for
triton_fwd; for the backward helpers use p=pol_grad, call_args=call_args,
spp=spp) so that each closure retains the current iteration's values instead of
the final ones.


n_fwd = _time_ms(native_fwd, args.warmup, args.iters)
t_fwd = _time_ms(triton_fwd, args.warmup, args.iters)
Expand All @@ -114,21 +117,21 @@ def triton_fwd_bwd(c=cur_grad):

rows.append(
[
f"{num_prompts}x{spp}x{comp_len}",
f"{n_tokens/1e6:.2f}M",
f"{num_prompts}x{spp}x{comp_len}x{vocab}",
f"{n_tokens/1e3:.0f}K",
f"{n_fwd:.3f}",
f"{t_fwd:.3f}",
f"{n_fwd/t_fwd:.2f}x",
f"{n_fb:.3f}",
f"{t_fb:.3f}",
f"{n_fb/t_fb:.2f}x",
f"{n_vram*1024:.1f}",
f"{t_vram*1024:.1f}",
f"{n_vram*1024:.0f}",
f"{t_vram*1024:.0f}",
]
)

headers = [
"shape (P x S x L)",
"shape (P x S x L x V)",
"tokens",
"native fwd ms",
"triton fwd ms",
Expand All @@ -144,21 +147,20 @@ def triton_fwd_bwd(c=cur_grad):

def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--iters", type=int, default=30)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--iters", type=int, default=20)
parser.add_argument("--warmup", type=int, default=5)
parser.add_argument("--clip-eps", type=float, default=0.2)
parser.add_argument("--beta", type=float, default=0.04)
parser.add_argument(
"--configs",
type=str,
default=None,
help="Semicolon-separated 'prompts,samples,len' triples, e.g. '64,8,512;128,8,1024'.",
help="Semicolon-separated 'prompts,samples,len,vocab' tuples, "
"e.g. '4,8,256,32768;4,8,256,131072'.",
)
args = parser.parse_args()
if args.configs:
args.configs = [
tuple(int(x) for x in triple.split(",")) for triple in args.configs.split(";")
]
args.configs = [tuple(int(x) for x in tup.split(",")) for tup in args.configs.split(";")]
else:
args.configs = DEFAULT_CONFIGS
return args
Expand Down
Loading
Loading