
I made a kernel 2.2x faster. It made my training loop 3x slower


Making Dr GRPO go brrr

02 Jun, 2026

I wrote a fused decode-attention kernel for an RL training loop, got it 2.2× faster than the SDPA path it replaces at the microbenchmark level, dropped it into Hugging Face's generate, and watched the decode step get nearly 3× slower. The kernel was doing exactly what the microbenchmark said it would. The integration broke an auto-compile path that the baseline was quietly benefiting from. This post is how I got there, what the gap actually was, and what closing it would have cost.

The wider context: this is the writeup of a project to RL-train a small open-source model on GSM8K and write CuteDSL kernels for whichever paths dominate. The concrete setup is Qwen2.5-0.5B-Instruct, Dr. GRPO, and a single A10G. The post covers two things: building the training loop from scratch (and squeezing 4.8× out of the rollout phase before any kernel work), and then writing the kernel above for the path that still dominated. Most of what follows is what those two facts look like sitting next to each other.

What is RL post-training, and why is it slow

In RL post-training for LLMs, you have a policy (the model), a verifier (something that scores outputs), and a loop that pushes the policy to produce higher-scoring outputs. For a math task like GSM8K, the verifier is just a regex that pulls the final number out of the model's response and compares it to the ground truth.
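A minimal sketch of such a regex verifier, to make the idea concrete. This is not the code from the project: the function names, the number pattern, and the binary 0/1 reward are assumptions chosen for illustration. It takes the last number in the response as the model's answer and compares it with the last number in the reference (GSM8K references end with a line like "#### 18").

```python
import re

# Integers or decimals, with an optional sign and thousands separators.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text):
    """Return the last number found in `text` as a float, or None."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    # Drop thousands separators before converting.
    return float(matches[-1].replace(",", ""))


def reward(response, ground_truth):
    """Hypothetical binary reward: 1.0 on a numeric match, else 0.0."""
    predicted = extract_final_number(response)
    target = extract_final_number(ground_truth)
    if predicted is None or target is None:
        return 0.0
    return 1.0 if abs(predicted - target) < 1e-6 else 0.0


print(reward("She sells 9 eggs at $2 each, so the answer is 18.", "#### 18"))
```

Taking the last number is a deliberate simplification: it rewards a correct final answer however the reasoning is formatted, but it will also accept a response that merely happens to end on the right number.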

Each training step has two phases.

Rollout. Sample a prompt. Generate G completions from the current policy. Score them. Compute advantages.

Update. For K inner epochs: forward pass through the policy, compute the GRPO loss against the rewards, backprop, optimizer step.
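The "compute advantages" step in the rollout phase can be sketched in a few lines. This assumes the usual description of Dr. GRPO, in which each completion's reward is centered on the mean of its group of G completions and, unlike the original GRPO, not divided by the group's standard deviation. The function name is hypothetical and the sketch handles a single prompt's group.

```python
def group_advantages(rewards):
    """Center each reward on its group mean (no std normalization).

    `rewards` holds the G scores for the completions of one prompt.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


# Four completions for one prompt, two of them correct.
print(group_advantages([1.0, 0.0, 0.0, 1.0]))
```

One consequence worth noticing: a group in which every completion gets the same reward yields all-zero advantages, so that prompt contributes no gradient signal for the step.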

Rollout dominates wall time. The reason is structural. Update is one big batched forward pass over (B*G, P+C) tokens, then a backward and a step. That's three large, well-batched phases of GPU work. Rollout is model.generate, which is a sequential decode loop that runs one forward pass per generated token, with each pass operating on (B*G, 1, hidden) plus a growing KV cache. Per-token compute is small, but you do it max_new_tokens times in serial. Even with KV cache and batching, you can't parallelize across the time dimension because each token depends on the ones before it.
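A back-of-the-envelope sketch of that asymmetry, counting forward passes and tokens per pass for each phase. The sizes for B, G, P and C below are hypothetical, picked only to make the arithmetic visible, and are not the values used in the project. The rollout count also treats every pass as a decode step, although the first pass (prefill) actually processes the whole prompt.

```python
def update_shape(batch, group, prompt_len, completion_len):
    """One batched forward pass over every prompt and completion token."""
    return {
        "forward_passes": 1,
        "tokens_per_pass": batch * group * (prompt_len + completion_len),
    }


def rollout_shape(batch, group, max_new_tokens):
    """One forward pass per generated token, one new token per sequence."""
    return {
        "forward_passes": max_new_tokens,
        "tokens_per_pass": batch * group,
    }


# Hypothetical sizes, for illustration only.
B, G, P, C = 4, 8, 128, 256
print("update: ", update_shape(B, G, P, C))
print("rollout:", rollout_shape(B, G, C))
```

With these numbers the update does one pass over 12,288 tokens, while the rollout does 256 serial passes of 32 tokens each. The total token count is similar, but the rollout work arrives in small slices that cannot be overlapped, which is why per-pass overhead matters so much there.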
