For prefill (multi-query) attention, we reimplemented the FlashAttention 2 algorithm in pure CUDA with some additional optimizations. The standard FlashAttention implementation uses Tensor Cores with fp16 inputs and an fp32 accumulator; however, RTX 4090 GPUs have lower Tensor Core throughput when accumulating in fp32. We observe that the $\frac{\mathbf{q}\cdot \mathbf{k}^{T}}{\sqrt{d}}$ phase of the attention computation has a small numerical range and can safely be accumulated in fp16, because the head dimension $d$ is always small (e.g. 128). FlashInfer therefore provides an `allow_fp16_qk_reduction` option to enable this optimization (while still using fp32 accumulation for $\mathbf{score} \cdot \mathbf{v}$), which can bring a 50% speedup on RTX 4090.
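As a usage sketch, the option can be toggled from FlashInfer's Python interface. The snippet below assumes the early (0.0.x-era) `single_prefill_with_kv_cache` entry point with the default NHD tensor layout; the exact function name and signature may differ in other releases.

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 32, 128
qo_len, kv_len = 4096, 4096

# fp16 query/key/value tensors in [seq_len, num_heads, head_dim] (NHD) layout.
q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

# allow_fp16_qk_reduction=True accumulates the q·k^T dot products in fp16,
# while the softmax and the score·v reduction stay in fp32. On GPUs with
# reduced fp32-accumulation Tensor Core throughput (e.g. RTX 4090), this
# trades a small amount of precision for higher attention throughput.
o = flashinfer.single_prefill_with_kv_cache(
    q, k, v, causal=True, allow_fp16_qk_reduction=True
)
```

Only the short, well-bounded $d$-length dot product runs at reduced precision; the rest of the kernel keeps fp32 accumulation, which is why the accuracy impact stays small.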
Below is the performance comparison of FlashInfer 0.0.1 and FlashAttention 2.4.2 on different GPUs: