
FlashInfer CUDA Kernel Integration for Decode Attention with Attention Weights#65

Open
vihangp wants to merge 2 commits into awslabs:main from vihangp:flashinfer_kernels

Conversation

@vihangp (Collaborator) commented Mar 16, 2026

Summary

This PR vendors FlashInfer's CUDA primitives into keys_values/csrc/ and builds custom prefill and decode kernels that return per-position attention weights — needed for H2O cache eviction. Standard attention libraries (Flash Attention, PyTorch SDPA) do not expose these weights during decode.

Key changes

  • Vendored FlashInfer headers (keys_values/csrc/flashinfer/) — low-level primitives (vec_t, state_t, ptx_exp2, warp shuffles) used by our custom kernels
  • Custom CUDA kernels (keys_values/csrc/kernels/) — prefill and decode kernels with attention weight support, compiled via PyTorch cpp_extension
  • Optimized decode kernel — batched launch, Q-in-registers, vectorized loads, warp-level reductions, logits caching to avoid recomputing Q*K in the attention weights pass
  • Python integration — FlashInferSDPA wrapper (flashinfer_wrapper.py) with automatic fallback to eager; attention.py auto-selects FlashInfer when return_attn_weights=True
  • Build system — setup.py with conditional CUDA compilation, half-precision flag fixes, multi-arch support (SM 70-90)
  • Cleanup — consolidated 6 progress docs into FLASHINFER_INTEGRATION.md, removed debug scripts, updated .gitignore
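The backend auto-selection described above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the helper names (`_flashinfer_available`, `select_attention_backend`) and the `keys_values.csrc` import are assumptions, and the real `FlashInferSDPA` wrapper lives in `flashinfer_wrapper.py`.

```python
def _flashinfer_available() -> bool:
    """Stand-in for the real check (CUDA present + extension compiled).

    The module path below is hypothetical; the actual extension name
    depends on how setup.py registers it.
    """
    try:
        import keys_values.csrc  # hypothetical compiled extension
        return True
    except ImportError:
        return False


def select_attention_backend(return_attn_weights: bool) -> str:
    # FlashInfer is only needed when per-position attention weights are
    # requested (e.g. for H2O cache eviction); otherwise PyTorch SDPA
    # is the fastest path.
    if return_attn_weights and _flashinfer_available():
        return "flashinfer"
    if return_attn_weights:
        return "eager"  # automatic fallback: eager also exposes weights
    return "sdpa"


print(select_attention_backend(return_attn_weights=False))  # prints "sdpa"
```

The key design point is that the fallback is silent and per-call: callers that need weights always get a backend that can return them, without configuring anything.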

Benchmark results (Qwen3-4B, A100, fp16, decode)

| Batch size | vs Eager (with weights) | vs PyTorch SDPA (no weights) |
|---|---|---|
| 4 | 1.68-1.80x faster | 0.96-1.17x (faster at 64K+) |
| 1 | 0.60-0.85x | 0.34-0.44x |

Prefill: 9-29x faster than eager, competitive with PyTorch SDPA.

Test plan

  • Verify CUDA extension compiles (python setup.py build_ext --inplace)
  • Run existing test suite (pytest test/)
  • Run decode benchmark on A100 (python benchmark_long_context.py --models Qwen3-4B)
  • Verify attention weights match eager implementation
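The last check above ("weights match eager") amounts to comparing the kernel's per-position weights against a reference softmax. A minimal NumPy sketch of that comparison for single-token decode, assuming one head and no masking; since the CUDA extension is not available here, the sketch reuses the eager reference in place of the kernel output to show the shape and tolerance of the check:

```python
import numpy as np


def eager_decode_weights(q, K, scale):
    """Reference decode attention weights: softmax(K @ q * scale)."""
    logits = (K @ q) * scale       # [seq_len]
    logits -= logits.max()         # numerical stability
    w = np.exp(logits)
    return w / w.sum()


rng = np.random.default_rng(0)
head_dim, seq_len = 64, 128
q = rng.standard_normal(head_dim).astype(np.float32)
K = rng.standard_normal((seq_len, head_dim)).astype(np.float32)
scale = 1.0 / np.sqrt(head_dim)

ref = eager_decode_weights(q, K, scale)
# In the real test, kernel_weights would come from the compiled extension;
# an fp16 kernel typically needs a loose tolerance such as atol=1e-3.
kernel_weights = eager_decode_weights(q, K, scale)

assert np.allclose(ref, kernel_weights, atol=1e-3)
assert abs(float(ref.sum()) - 1.0) < 1e-5  # weights form a distribution
```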

@mseeger (Contributor) left a comment

Your PR seems to be touching a lot of files. Please check this:

  • Rebase your changes on top of the current main branch. This should resolve almost all of the changes to the project's existing files.
  • Are you sure you need all the new files? They seem to be taken from flashinfer. Could we not make flashinfer a dependency, and only copy over the files that need a change?

