
I rebuilt FlashAttention in Triton to understand the performance archaeology



Flash Attention: From Theory to Implementation

Flash Attention has become one of the most impactful optimizations in modern deep learning. Since the original paper was published in 2022, we’ve seen four major versions, each squeezing more performance out of increasingly powerful hardware. But here’s the thing: reading the papers is one thing; understanding why each optimization was made is another entirely.

My goal here is simple: start from first principles, implement FlashAttention v1 exactly as described in the paper, profile it, find the bottlenecks, and see how far we can push it. We’ll build intuition by iterating on the algorithm, discovering through profiling exactly why v2, v3, and v4 were necessary. Think of this as archaeology—digging through the performance layers to understand what each version was really solving.

So let’s put ourselves in the shoes of that mythical Stanford grad student. You’ve finally finished configuring your Neovim and Arch Linux setup (a multi-year endeavor, naturally). You open up a fresh LLaMA model to peek under the hood. Text goes in, gets tokenized and embedded, then flows through a stack of transformer blocks. Standard stuff. But then you look closer at the attention mechanism. Three projections, and then… there it is:

```python
scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, v)  # (B, N_h, S, D_h)
```

This monstrosity of code is staring you right in the face. For those who don’t immediately see the problem with these three lines, let me annotate how they would normally execute in PyTorch’s eager mode (without compilation):

```python
# 1. Load q (B, N_h, S, D_h) and k (B, N_h, S, D_h) from HBM.
# 2. Compute Q @ K^T and write it back to HBM. Note that scores is (B, N_h, S, S).
#    In eager mode the division by sqrt(head_dim) is a separate kernel that
#    reads and writes the full (B, N_h, S, S) tensor once more.
scores = torch.matmul(q, k.transpose(3, 2)) / math.sqrt(self.head_dim)
# 3. Reload scores from HBM, compute the softmax, and write the result back to HBM.
scores = F.softmax(scores, dim=-1)
# 4. Load v (B, N_h, S, D_h) and the scores from HBM.
# 5. Compute scores @ v and write the output back to HBM.
output = torch.matmul(scores, v)  # (B, N_h, S, D_h)
```

Do you see it now? We have three tensors q, k, v, each of shape (B, N_h, S, D_h). The output of the attention mechanism is also (B, N_h, S, D_h), and yet somewhere in the middle we had to materialize a (B, N_h, S, S) tensor for funsies. That is the first critical bottleneck: quadratic memory complexity. Take a standard training sequence length of S=8192. Computing attention naively requires O(S²) memory to store the full attention matrix: S² ≈ 67M elements per head, or 128 MiB in fp16, so with 32 heads (as in LLaMA-7B) that is already 4 GiB for a single sequence in a single attention layer. Crucially, in modern transformers S >> D_h (the sequence length is much larger than the head dimension): we typically have S=8192 or more, while D_h=64 or 128. This massive asymmetry is what makes the Flash Attention algorithm possible.

The second big issue is the back and forth to HBM. Modern GPUs have compute throughput that vastly exceeds their memory bandwidth, so repeatedly reading Q, K, V, and the scores from High Bandwidth Memory (HBM), which is slow compared to on-chip SRAM, is going to hurt performance badly (more on this later).
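To put rough numbers on that back and forth, here is a hypothetical first-order model (the function name and its assumptions are mine, not from any library or paper) that counts the HBM bytes moved by the eager steps annotated above, including the division by sqrt(D_h), which eager PyTorch runs as its own kernel. It compares them with an ideal fused kernel that reads q, k, v once and writes the output once. It assumes each op reads its inputs once and writes its output once, and ignores caches and the tile re-reads inside matmul kernels, so treat both numbers as rough estimates rather than measurements.

```python
# Hypothetical first-order model of HBM traffic for eager attention.
# Assumption: each op reads its inputs from HBM once and writes its output once;
# caching and the extra tile re-reads inside the matmul kernels are ignored.

def eager_attention_hbm_bytes(B, N_h, S, D_h, bytes_per_elem=2):
    """Return (eager_bytes, fused_bytes) moved to/from HBM."""
    qkv = B * N_h * S * D_h       # elements in each of q, k, v and the output
    scores = B * N_h * S * S      # elements in the (B, N_h, S, S) scores tensor
    eager = (
        2 * qkv + scores          # q @ k^T: read q, k; write scores
        + 2 * scores              # / sqrt(D_h): read scores; write scaled scores
        + 2 * scores              # softmax: read scores; write probabilities
        + scores + 2 * qkv        # probs @ v: read probs, v; write output
    )
    fused = 4 * qkv               # ideal: read q, k, v once; write output once
    return eager * bytes_per_elem, fused * bytes_per_elem


# Assumed LLaMA-7B-like shapes, fp16 (2 bytes per element).
eager, fused = eager_attention_hbm_bytes(B=1, N_h=32, S=8192, D_h=128)
print(f"eager: {eager / 2**30:.2f} GiB, fused: {fused / 2**20:.0f} MiB, "
      f"ratio: {eager / fused:.0f}x")
# prints "eager: 24.25 GiB, fused: 256 MiB, ratio: 97x"
```

Under this model the ratio works out to 1 + 1.5·S/D_h, so the gap grows linearly with S/D_h: the S >> D_h asymmetry is exactly what makes the eager traffic dominated by the (B, N_h, S, S) tensor.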

The whole idea of Flash Attention is to bypass these intermediate steps: go from the tensors q, k, v directly to the output tensor and compute attention in one go, with a minimal memory footprint (materializing only the tensors we need), minimal back and forth to HBM, and, hopefully, O(S) memory complexity.
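The trick that makes a single pass possible is computing the softmax online: stream over the keys in blocks and keep only a running maximum, a running denominator, and a running weighted sum of values, rescaling them whenever the maximum changes. Below is a minimal pure-Python sketch of that idea for a single query row, assuming plain lists instead of tensors. The function names and the `block_size` parameter are hypothetical, and this is not the tiled GPU kernel, just the rescaling math FlashAttention builds on. The extra state per query is O(D_h) regardless of S, and a naive reference that materializes the full score row is included to check the result.

```python
import math


def attention_row_naive(q, keys, values):
    """Reference: materializes all S scores, then softmax, then weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_row_online(q, keys, values, block_size=2):
    """Same output, streaming over keys in blocks with an online softmax.

    Only one block of scores exists at a time, plus a running max, a running
    denominator and a D_h-sized accumulator.
    """
    scale = 1.0 / math.sqrt(len(q))
    row_max = -math.inf            # running max of the scores seen so far
    denom = 0.0                    # running sum of exp(score - row_max)
    acc = [0.0] * len(values[0])   # running sum of exp(score - row_max) * v
    for start in range(0, len(keys), block_size):
        block_k = keys[start:start + block_size]
        block_v = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in block_k]
        new_max = max(row_max, max(scores))
        correction = math.exp(row_max - new_max)  # 0.0 on the first block
        denom *= correction                       # re-express old sums
        acc = [a * correction for a in acc]       # relative to the new max
        for s, v in zip(scores, block_v):
            p = math.exp(s - new_max)
            denom += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        row_max = new_max
    return [a / denom for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -1.0], [0.3, 0.8], [-0.5, 0.4], [0.9, 0.9]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
streamed = attention_row_online(q, keys, values, block_size=2)
reference = attention_row_naive(q, keys, values)
assert all(abs(a - b) < 1e-9 for a, b in zip(streamed, reference))
```

The correction factor exp(old_max - new_max) is what lets each block be visited exactly once: earlier partial sums were computed relative to an older maximum, and multiplying by this factor re-expresses them relative to the new one, so the final division by the denominator yields exactly the softmax-weighted sum.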
