Skip to content

[TPU][Pallas] lower hl.jagged_tile via a per-item DMA template#2614

Closed
yarongmu-google wants to merge 4 commits into
pytorch:mainfrom
yarongmu-google:jagged-pallas-codegen
Closed

[TPU][Pallas] lower hl.jagged_tile via a per-item DMA template#2614
yarongmu-google wants to merge 4 commits into
pytorch:mainfrom
yarongmu-google:jagged-pallas-codegen

Conversation

@yarongmu-google
Copy link
Copy Markdown
Collaborator

@yarongmu-google yarongmu-google commented May 27, 2026

Helion's existing Pallas lowering for hl.jagged_tile goes through a one-hot × matmul gather, which OOMs on real inputs and produces wrong output on TPU. This PR replaces it for the (single items axis, sum-shaped) case with a per-item DMA-orchestrated template that prefetches each item's data into VMEM with double-buffered ping-pong, accumulates in fp32, and flushes the result via a separate double-buffered output DMA — no cross-program writes, no gather.

A pair-based detector (detect_jagged_dispatch) decides at codegen time whether the kernel fits the template; if so, the device function emission is skipped and the host wrapper calls default_pallas_jagged_reduce_launcher instead of _launcher(...). Triton and other backends are unaffected. Verified end-to-end on TPU with examples/jagged_sum.py.

After this PR:

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.2838       21.86x         
torch                6.2051       1.00x (ref)    
=================================================================

Note the perf is slower than #2596's 22.85x because this doesn't modify the original helion kernel which has 3 nested loops, while #2596 only has two nested loops in teh rewritten helion kernel.

Adds helion.runtime.pallas_templates.jagged_reduce_pallas, a JAX
Pallas kernel that computes a per-item reduction over the
second-minor axis of a jagged tensor:

    out[i, :] = reduce(jagged_data[offsets[i] : offsets[i+1], :])

Pattern: single-program grid; pl.loop iterates over items; each
item's rows are tiled, fetched into VMEM via pltpu.make_async_copy
with double-buffered ping-pong, accumulated in fp32, and flushed
through a separate double-buffered output DMA. No cross-program
writes, so no atomic accumulation is needed.

Standalone template (no Helion call sites yet); the dispatch
detector and codegen integration land in follow-up commits.

pyproject: exempt this module from ANN001/ANN202 because the
template is plumbing Pallas Refs and traced JAX values that don't
have meaningful Python types.
Adds detect_jagged_dispatch(env) -> int | None, which reads
CompileEnvironment.jagged_tile_parent_ids and returns the
items-axis block_id when every hl.jagged_tile in the kernel shares
a single hl.tile parent (i.e. the kernel has a single items axis).
Returns None otherwise so codegen can fall through.

Detector is pure (no IR mutation, no FX walk); follow-up codegen
commit reads the returned block_id to route the call through the
jagged_reduce template.

Unit tests (8) cover: no jagged tiles, single pair, multiple pairs
sharing the same parent, child with multiple parents, multiple items
axes, mixed dispatchable/non-dispatchable, and the defensive zero-
parent path.
Wires the dispatch detector and template together:

  - PallasBackend.pre_codegen calls detect_jagged_dispatch() and
    stashes the items-axis block_id on the DeviceFunction when the
    kernel matches the pair-based pattern.

  - DeviceFunction.codegen_function_call short-circuits to
    _codegen_jagged_dispatch_call when the flag is set, emitting
    `<out> = _default_pallas_jagged_reduce_launcher(jagged_data, offsets)`
    instead of the gather-based _launcher(...) call. The matching
    codegen_function_def skips emitting the device kernel as dead
    code so the generated module compiles.

  - default_pallas_jagged_reduce_launcher (helion.runtime) bridges
    torch tensors to the JAX template: jits jagged_reduce_pallas
    once per (template-fn-id, jagged_tile_size), wraps it in a
    cached JaxCallable on TPU, or uses the DLPack bridge in
    interpret mode. Single jit layer, no nesting.

Input identification reads HostFunction.params.arguments (the user's
kernel parameter names) and classifies by ndim: 2-D = jagged_data,
1-D = jagged_dim_offsets. Iterating sorted_args was incorrect because
intermediates like `x_flat = x_data.view(-1)` share storage with
x_data and surface as 1-D TensorArgs, so the launcher would receive
the view instead of the original 2-D parameter.

Output is identified by storage identity: a TensorArg whose storage
is not in env.input_sources was created in the kernel body
(via torch.zeros etc.) and is the output to return.

Currently sum-only and expects exactly the (2-D data, 1-D offsets)
signature on the kernel; richer signatures (e.g. jagged_mean's
feature counts) are out of scope for this commit.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 27, 2026
Two additions:

  - test/test_pallas_jagged_reduce.py — 8 launcher-direct cases that
    call default_pallas_jagged_reduce_launcher from torch and compare
    against a scatter-sum reference. Covers branches the example test
    doesn't naturally hit: empty items, partial-tail tiles, multi-tile
    items, many items, lane-padded M (M_actual not a multiple of 128),
    bf16 cast on flush, int64 offsets coercion, and the single-item
    prologue/epilogue path.

  - test/test_examples.py::test_jagged_sum — flip the decorator from
    @xfailIfPallas to @xfailIfPallasInterpret. The kernel now runs
    correctly on real TPU through the per-item DMA template; interpret
    mode still xfails because the template uses pltpu primitives that
    require a TPU device.
@yarongmu-google yarongmu-google requested review from cota, jansel and oulgen and removed request for cota, jansel and oulgen May 27, 2026 22:55
@yarongmu-google
Copy link
Copy Markdown
Collaborator Author

This is now subsumed into #2616, which also handles jagged_mean (2 jagged dims instead of 1) so I will close this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant