[TPU][Pallas] lower hl.jagged_tile via a per-item DMA template #2614
Closed
yarongmu-google wants to merge 4 commits into
Adds helion.runtime.pallas_templates.jagged_reduce_pallas, a JAX
Pallas kernel that computes a per-item reduction over the
second-minor axis of a jagged tensor:
out[i, :] = reduce(jagged_data[offsets[i] : offsets[i+1], :])
Pattern: single-program grid; pl.loop iterates over items; each
item's rows are tiled, fetched into VMEM via pltpu.make_async_copy
with double-buffered ping-pong, accumulated in fp32, and flushed
through a separate double-buffered output DMA. No cross-program
writes, so no atomic accumulation is needed.
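The access pattern described above can be modelled in plain Python. This is only a sketch of the tiling and the buffer alternation, not the Pallas kernel: `jagged_reduce_model`, `tile_rows`, and the list-based buffers are hypothetical stand-ins, the "DMA" is an ordinary slice copy, and the model runs sequentially, whereas the real template overlaps the next fetch with the current accumulation.

```python
def jagged_reduce_model(jagged_data, offsets, tile_rows=2):
    """Plain-Python model of out[i, :] = sum(jagged_data[offsets[i]:offsets[i+1], :]).

    Hypothetical stand-in for the Pallas template: the "DMA" is a list slice
    and the two VMEM buffers are two Python lists indexed by `slot`.
    """
    num_items = len(offsets) - 1
    width = len(jagged_data[0]) if jagged_data else 0
    out = []
    slot_trace = []  # (item, slot) for every tile fetched, for inspection
    for i in range(num_items):              # stands in for pl.loop over items
        start, end = offsets[i], offsets[i + 1]
        acc = [0.0] * width                 # accumulator (fp32 in the real kernel)
        buffers = [None, None]              # ping-pong input buffers
        slot = 0
        for row in range(start, end, tile_rows):
            # "Fetch" one tile; the last tile of an item may be partial.
            buffers[slot] = jagged_data[row:min(row + tile_rows, end)]
            slot_trace.append((i, slot))
            for r in buffers[slot]:
                for c in range(width):
                    acc[c] += r[c]
            slot = 1 - slot                 # alternate between the two buffers
        out.append(acc)                     # each output row is written once
    return out, slot_trace


data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
offsets = [0, 3, 3, 5]                      # item 1 is empty
result, trace = jagged_reduce_model(data, offsets)
print(result)  # [[9.0, 12.0], [0.0, 0.0], [16.0, 18.0]]
```

Because each item owns exactly one output row, the model never adds into a row from two places, which mirrors why the template needs no atomic accumulation.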
Standalone template (no Helion call sites yet); the dispatch
detector and codegen integration land in follow-up commits.
pyproject: exempt this module from ANN001/ANN202 because the
template is plumbing Pallas Refs and traced JAX values that don't
have meaningful Python types.
Adds detect_jagged_dispatch(env) -> int | None, which reads
CompileEnvironment.jagged_tile_parent_ids and returns the items-axis
block_id when every hl.jagged_tile in the kernel shares a single
hl.tile parent (i.e. the kernel has a single items axis). Returns
None otherwise so codegen can fall through.
Detector is pure (no IR mutation, no FX walk); follow-up codegen
commit reads the returned block_id to route the call through the
jagged_reduce template.
Unit tests (8) cover: no jagged tiles, single pair, multiple pairs
sharing the same parent, child with multiple parents, multiple items
axes, mixed dispatchable/non-dispatchable, and the defensive
zero-parent path.
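The decision rule can be sketched as a small pure function. The exact type of `CompileEnvironment.jagged_tile_parent_ids` is not shown in this PR text, so the sketch assumes a mapping from each jagged tile's block_id to the block_ids of its parents; `detect_single_items_axis` is a hypothetical name, not the function added by the commit.

```python
from typing import Mapping, Optional, Sequence


def detect_single_items_axis(
    parent_ids: Mapping[int, Sequence[int]],
) -> Optional[int]:
    """Return the shared parent block_id, or None if the pattern does not match.

    Assumed input shape: jagged-tile block_id -> parent block_ids.
    """
    if not parent_ids:
        return None                      # no jagged tiles in the kernel
    items_axis: Optional[int] = None
    for parents in parent_ids.values():
        if len(parents) != 1:
            return None                  # zero parents, or several parents
        (parent,) = parents
        if items_axis is None:
            items_axis = parent
        elif parent != items_axis:
            return None                  # more than one items axis
    return items_axis


print(detect_single_items_axis({1: [0], 2: [0]}))  # 0
print(detect_single_items_axis({1: [0], 2: [4]}))  # None
```

The comparison uses `is None` rather than truthiness so that a valid block_id of 0 is not mistaken for "no match".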
Wires the dispatch detector and template together:
- PallasBackend.pre_codegen calls detect_jagged_dispatch() and
stashes the items-axis block_id on the DeviceFunction when the
kernel matches the pair-based pattern.
- DeviceFunction.codegen_function_call short-circuits to
_codegen_jagged_dispatch_call when the flag is set, emitting
`<out> = _default_pallas_jagged_reduce_launcher(jagged_data, offsets)`
instead of the gather-based _launcher(...) call. The matching
codegen_function_def skips emitting the device kernel as dead
code so the generated module compiles.
- default_pallas_jagged_reduce_launcher (helion.runtime) bridges
torch tensors to the JAX template: jits jagged_reduce_pallas
once per (template-fn-id, jagged_tile_size), wraps it in a
cached JaxCallable on TPU, or uses the DLPack bridge in
interpret mode. Single jit layer, no nesting.
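The "jit once per (template-fn-id, jagged_tile_size)" behaviour amounts to a keyed cache in front of the compile step. The sketch below shows that idea only; `get_compiled`, `compile_fn`, and `_COMPILED` are hypothetical names, and the fake compiler stands in for the real jit and JaxCallable wrapping.

```python
from typing import Any, Callable, Dict, Tuple

_COMPILED: Dict[Tuple[int, int], Callable[..., Any]] = {}


def get_compiled(template_fn, jagged_tile_size, compile_fn):
    """Return a compiled callable, compiling at most once per cache key.

    `compile_fn` is a hypothetical hook standing in for the real jit step.
    """
    key = (id(template_fn), jagged_tile_size)
    compiled = _COMPILED.get(key)
    if compiled is None:
        compiled = compile_fn(template_fn, jagged_tile_size)
        _COMPILED[key] = compiled
    return compiled


# Usage with a counting fake compiler.
compile_calls = []


def fake_compile(fn, tile_size):
    compile_calls.append(tile_size)
    return lambda *args: fn(*args, tile_size=tile_size)


def template(x, tile_size):
    return (x, tile_size)


first = get_compiled(template, 8, fake_compile)
second = get_compiled(template, 8, fake_compile)   # cache hit, no recompile
third = get_compiled(template, 16, fake_compile)   # new tile size, new entry
print(compile_calls)  # [8, 16]
```

Keying on `id(template_fn)` is only safe while the function object stays alive, which holds for a module-level template.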
Input identification reads HostFunction.params.arguments (the user's
kernel parameter names) and classifies by ndim: 2-D = jagged_data,
1-D = jagged_dim_offsets. Iterating sorted_args was incorrect because
intermediates like `x_flat = x_data.view(-1)` share storage with
x_data and surface as 1-D TensorArgs, so the launcher would receive
the view instead of the original 2-D parameter.
Output is identified by storage identity: a TensorArg whose storage
is not in env.input_sources was created in the kernel body
(via torch.zeros etc.) and is the output to return.
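The two identification rules above can be illustrated with stand-in objects. This is a sketch under assumptions: `FakeTensor`, `classify_inputs`, `find_outputs`, and the `x_offsets` name are hypothetical, and storage identity is modelled as a plain integer rather than a real tensor storage.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor argument: only ndim and a storage id."""
    ndim: int
    storage_id: int


def classify_inputs(
    param_names: List[str], tensors: Dict[str, FakeTensor]
) -> Tuple[str, str]:
    """Pick (jagged_data, offsets) from the kernel's own parameters by ndim."""
    two_d = [n for n in param_names if tensors[n].ndim == 2]
    one_d = [n for n in param_names if tensors[n].ndim == 1]
    if len(two_d) != 1 or len(one_d) != 1:
        raise ValueError("expected exactly one 2-D and one 1-D parameter")
    return two_d[0], one_d[0]


def find_outputs(
    tensors: Dict[str, FakeTensor], input_storage_ids: Set[int]
) -> List[str]:
    """A tensor whose storage is not an input's storage was created in the body."""
    return [n for n, t in tensors.items() if t.storage_id not in input_storage_ids]


tensors = {
    "x_data": FakeTensor(ndim=2, storage_id=100),
    "x_offsets": FakeTensor(ndim=1, storage_id=200),
    "x_flat": FakeTensor(ndim=1, storage_id=100),   # view sharing x_data's storage
    "out": FakeTensor(ndim=2, storage_id=300),      # allocated inside the kernel
}
params = ["x_data", "x_offsets"]                    # user-visible parameters only
data_name, offsets_name = classify_inputs(params, tensors)
outputs = find_outputs(tensors, {tensors[p].storage_id for p in params})
print(data_name, offsets_name, outputs)  # x_data x_offsets ['out']
```

Restricting classification to the parameter list keeps the 1-D view `x_flat` out of the candidate set, and the storage check excludes it from the outputs because it aliases an input.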
Currently sum-only and expects exactly the (2-D data, 1-D offsets)
signature on the kernel; richer signatures (e.g. jagged_mean's
feature counts) are out of scope for this commit.
Two additions:
- test/test_pallas_jagged_reduce.py — 8 launcher-direct cases that
call default_pallas_jagged_reduce_launcher from torch and compare
against a scatter-sum reference. Covers branches the example test
doesn't naturally hit: empty items, partial-tail tiles, multi-tile
items, many items, lane-padded M (M_actual not a multiple of 128),
bf16 cast on flush, int64 offsets coercion, and the single-item
prologue/epilogue path.
- test/test_examples.py::test_jagged_sum — flip the decorator from
@xfailIfPallas to @xfailIfPallasInterpret. The kernel now runs
correctly on real TPU through the per-item DMA template; interpret
mode still xfails because the template uses pltpu primitives that
require a TPU device.
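For readers unfamiliar with the scatter-sum formulation used as the reference, the idea is to map every row to its owning item and add it into that item's bucket. The sketch below is a plain-Python illustration, not the code in test/test_pallas_jagged_reduce.py; `scatter_sum_reference` is a hypothetical name.

```python
from bisect import bisect_right


def scatter_sum_reference(jagged_data, offsets):
    """Sum rows into per-item buckets by mapping each row to its item index."""
    num_items = len(offsets) - 1
    width = len(jagged_data[0]) if jagged_data else 0
    out = [[0.0] * width for _ in range(num_items)]
    for row_idx in range(offsets[0], offsets[-1]):
        # The owning item is the last i with offsets[i] <= row_idx, so empty
        # items (offsets[i] == offsets[i + 1]) never receive a row.
        item = bisect_right(offsets, row_idx) - 1
        for c, value in enumerate(jagged_data[row_idx]):
            out[item][c] += value
    return out


rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
with_empty_item = scatter_sum_reference(rows, [0, 3, 3, 5])
single_item = scatter_sum_reference(rows, [0, 5])
print(with_empty_item)  # [[9.0, 12.0], [0.0, 0.0], [16.0, 18.0]]
```

The two calls correspond to the empty-item and single-item cases listed above; an empty item yields a zero row rather than being dropped.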
This is now subsumed into #2616, which also handles jagged_mean (2 jagged dims instead of 1), so I will close this.
Helion's existing Pallas lowering for `hl.jagged_tile` goes through a one-hot × matmul gather, which OOMs on real inputs and produces wrong output on TPU. This PR replaces it for the (single items axis, sum-shaped) case with a per-item DMA-orchestrated template that prefetches each item's data into VMEM with double-buffered ping-pong, accumulates in fp32, and flushes the result via a separate double-buffered output DMA: no cross-program writes, no gather.

A pair-based detector (`detect_jagged_dispatch`) decides at codegen time whether the kernel fits the template; if so, the device function emission is skipped and the host wrapper calls `default_pallas_jagged_reduce_launcher` instead of `_launcher(...)`. Triton and other backends are unaffected. Verified end-to-end on TPU with `examples/jagged_sum.py`.

After this PR:
Note that the perf is slower than #2596's 22.85x because this PR doesn't modify the original Helion kernel, which has 3 nested loops, while the rewritten Helion kernel in #2596 has only 2.