In this post, I will walk through how I learned to implement Flash Attention for the 5090 in CUDA C++. The main objective is to learn to write attention in CUDA C++, since many features are not available in Triton, such as MXFP8 / NVFP4 MMA for sm120. I also feel this is a natural next step after learning about matmul kernels. Lastly, there are many excellent blog posts on writing fast matmul kernels, but none on attention, so I want to take this chance to write one up properly.
Readers are strongly encouraged to be familiar with CUDA C++ and with how to use Tensor Cores on NVIDIA GPUs. Of course, you can still read along and ask your favourite LLM for clarification along the way. Alternatively, you can check out the GPU-MODE series (slides, YouTube) for basic CUDA C++ knowledge, as well as the excellent matmul blog posts mentioned above, to quickly get up to speed.
You can find the full implementation discussed in this post here: https://github.com/gau-nernst/learn-cuda/tree/e83c256/07_attention. For bs=1, num_heads=8, len_query=4096, len_kv=8192, on a 5090 @ 400W, compiled with CUDA 12.9, I obtained the following benchmark results (the theoretical limit of the 5090 is 209.5 TFLOPS for BF16):
| Kernel | TFLOPS | % of SOL |
|---|---|---|
| F.sdpa() (Flash Attention) | 186.73 | 89.13% |
| F.sdpa() (CuDNN) | 203.61 | 97.19% |
| flash-attn | 190.58 | 90.97% |
| v1 (basic) | 142.87 | 68.20% |
| v2 (shared memory swizzling) | 181.11 | 86.45% |
| v3 (2-stage pipelining) | 189.84 | 90.62% |
| v4 (ldmatrix.x4 for K and V) | 194.33 | 92.76% |
| v5 (better pipelining) | 197.74 | 94.39% |
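To make the "% of SOL" column concrete, here is a small sketch of the arithmetic behind it. The percentage is simply the measured TFLOPS divided by the 209.5 TFLOPS BF16 limit quoted above. The FLOP count assumes the common convention of counting only the two matmuls (Q @ K^T and P @ V); the function names and `head_dim=128` are illustrative assumptions, since the head dimension is not stated in this section.

```python
PEAK_BF16_TFLOPS = 209.5  # theoretical 5090 BF16 limit quoted in the post


def attention_flops(bs, num_heads, len_q, len_kv, head_dim):
    # Assumed convention: count only the two matmuls, Q @ K^T and P @ V.
    # Each one costs 2 * len_q * len_kv * head_dim FLOPs per head.
    return 4 * bs * num_heads * len_q * len_kv * head_dim


def tflops_from_latency(flops, seconds):
    # Achieved throughput for a kernel that takes `seconds` to run.
    return flops / seconds / 1e12


def percent_of_sol(tflops, peak=PEAK_BF16_TFLOPS):
    # Speed-of-Light percentage: achieved throughput relative to the peak.
    return 100.0 * tflops / peak


# head_dim=128 is a hypothetical value used only for illustration.
flops = attention_flops(bs=1, num_heads=8, len_q=4096, len_kv=8192, head_dim=128)
print(f"FLOPs per forward pass (assuming head_dim=128): {flops}")

for name, tflops in [("v1 (basic)", 142.87), ("v5 (better pipelining)", 197.74)]:
    print(f"{name}: {percent_of_sol(tflops):.2f}% of SOL")
```

Dividing each TFLOPS figure in the table by 209.5 this way reproduces the percentages listed.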
Do note that although I only use Ampere features in these implementations (sm120 supports cp.async.bulk, i.e. TMA, but I don’t use it here), my implementations might not perform well on earlier generations of GPUs. Because of improvements in newer hardware, you might need more tricks to reach Speed-of-Light on older GPUs, e.g. pipelining data movement from shared memory to registers.
Flash Attention algorithm
Let’s start with the reference implementation of attention.
    from torch import Tensor


    def sdpa(q: Tensor, k: Tensor, v: Tensor):
        # q: [B, Lq, DIM]
        # k: [B, Lk, DIM]
        # v: [B, Lk, DIM]
        D = q.shape[-1]
        scale = D ** -0.5
        attn = (q @ k.transpose(-1, -2)) * scale  # [B, Lq, Lk]
        attn = attn.softmax(dim=-1)
        out = attn @ v  # [B, Lq, DIM]
        return out
Technically, if the inputs are BF16, some computations should remain in FP32, especially softmax. For brevity, the reference implementation above omits this.
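As a rough illustration of what keeping softmax in FP32 could look like, here is a minimal NumPy sketch. NumPy has no BF16 type, so float16 stands in for the low-precision input format; the function name `sdpa_fp32_softmax` is hypothetical, and this is not how the kernels in the post are structured (for simplicity, everything is upcast here, whereas a real kernel would feed low-precision operands to the Tensor Cores and accumulate in FP32).

```python
import numpy as np


def sdpa_fp32_softmax(q, k, v):
    # q: [B, Lq, DIM], k: [B, Lk, DIM], v: [B, Lk, DIM], low-precision inputs
    scale = q.shape[-1] ** -0.5

    # Upcast so that scores and softmax are computed in FP32.
    q32, k32, v32 = (x.astype(np.float32) for x in (q, k, v))

    s = (q32 @ k32.transpose(0, 2, 1)) * scale  # [B, Lq, Lk], FP32
    s = s - s.max(axis=-1, keepdims=True)       # subtract row max for stability
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)       # softmax in FP32

    out = p @ v32                                # [B, Lq, DIM], FP32 accumulation
    return out.astype(q.dtype)                   # cast back to the input dtype


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, 16)).astype(np.float16)
k = rng.standard_normal((2, 12, 16)).astype(np.float16)
v = rng.standard_normal((2, 12, 16)).astype(np.float16)
print(sdpa_fp32_softmax(q, k, v).dtype)
```

Only the final result is rounded back to the input precision; all intermediate values stay in FP32.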
We are implementing the algorithm outlined in the Flash Attention 2 paper. Each threadblock is responsible for a chunk of Q, and we iterate along the sequence length of KV. A Python-like outline of the algorithm is shown below (S and P follow Flash Attention notation).
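To illustrate that loop structure, here is a minimal NumPy sketch of tiled attention with online softmax, for a single head and no batch dimension. It is not the post's own outline: the function name, tile sizes (`block_q`, `block_kv`), and variable names are assumptions chosen for readability, and NumPy stands in for PyTorch to keep the example self-contained.

```python
import numpy as np


def flash_attention_sketch(q, k, v, block_q=64, block_kv=32):
    # q: [Lq, DIM], k: [Lk, DIM], v: [Lk, DIM]
    len_q, dim = q.shape
    len_kv = k.shape[0]
    scale = dim ** -0.5
    out = np.empty((len_q, dim), dtype=np.float32)

    # Each outer iteration plays the role of one threadblock: one chunk of Q.
    for q_start in range(0, len_q, block_q):
        q_tile = q[q_start:q_start + block_q].astype(np.float32)
        rows = q_tile.shape[0]

        row_max = np.full((rows, 1), -np.inf, dtype=np.float32)  # running max
        row_sum = np.zeros((rows, 1), dtype=np.float32)          # running softmax denominator
        o_tile = np.zeros((rows, dim), dtype=np.float32)         # unnormalized output

        # Iterate along the sequence length of KV.
        for kv_start in range(0, len_kv, block_kv):
            k_tile = k[kv_start:kv_start + block_kv].astype(np.float32)
            v_tile = v[kv_start:kv_start + block_kv].astype(np.float32)

            s = (q_tile @ k_tile.T) * scale                      # S = Q K^T
            new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
            rescale = np.exp(row_max - new_max)                  # 0 on the first tile
            p = np.exp(s - new_max)                              # P, unnormalized

            # Rescale the previous partial results, then accumulate this tile.
            row_sum = row_sum * rescale + p.sum(axis=-1, keepdims=True)
            o_tile = o_tile * rescale + p @ v_tile
            row_max = new_max

        # Normalize once, after all KV tiles have been processed.
        out[q_start:q_start + rows] = o_tile / row_sum

    return out
```

The full [Lq, Lk] attention matrix is never materialized: only one tile of S and P exists at a time, while the running row maximum and row sum keep the softmax numerically consistent across tiles.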