TLDR: We're releasing a throughput-optimized megakernel for tensor-parallel inference with Llama-70B on H100s. Our kernel can aggressively overlap compute, memory, and communication ops in order to simultaneously use the different hardware resources available on a GPU. When integrated into the Tokasaurus inference engine, our megakernel can outperform SGLang by >22% on end-to-end throughput (measured as time to finish 65,536 prompts from the ShareGPT benchmark). We're releasing the code here; please be warned that this really is research code; it is sensitive to compiler versions, GPU setup, and sometimes even being looked at the wrong way, and we have no intention whatsoever of supporting it. We hope you'll find the ideas and results interesting nonetheless!
Figure 1: Zoooommmm
A few months ago, we showed how we could fuse an entire model forward pass into a single "megakernel" in order to deliver low-latency inference with Llama-1B. In that post, we teased that many of the same concepts we introduced would also be useful for optimizing for throughput. We're now excited to bring receipts and release a new megakernel optimized for high-throughput inference with Llama-70B.
The inference workloads targeted by our low-latency and high-throughput megakernels are quite different and require distinct optimizations. Our low-latency megakernel targeted inference using Llama-1B when running on a single GPU with batch size one. This workload was entirely memory bound, and our focus was therefore on eliminating stalls that delayed loading model weights from global memory.
With large-batch Llama-70B inference, our workload is much more heterogeneous. Large portions of it (e.g. matrix multiplies, attention prefill) are compute-bound. Other parts (e.g. attention decode, RMS norm) are still bottlenecked by global memory bandwidth. Additionally, by distributing our model across multiple GPUs, we now need to perform cross-GPU communication that is limited by the bandwidth of the NVLink connections between devices. If we run these components sequentially, we've paid for the whole GPU, but are only using little bits and pieces of it at a time. :(
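To make the compute-bound versus memory-bound distinction concrete, here is a rough roofline-style sketch. The peak figures are approximate public numbers for an H100 SXM and are assumptions here, not measurements from this work; the matrix shapes are illustrative (8192 is Llama-70B's hidden size, but real tensor-parallel shards will differ), and the helper names are hypothetical.

```python
# Approximate H100 SXM peaks (assumptions; check the datasheet for your part).
PEAK_BF16_FLOPS = 989e12  # dense BF16 tensor-core FLOP/s
PEAK_HBM_BYTES = 3.35e12  # HBM bytes/s


def ridge_point():
    """FLOPs per byte above which an op is compute-bound on this simple roofline."""
    return PEAK_BF16_FLOPS / PEAK_HBM_BYTES


def matmul_intensity(m, k, n, bytes_per_elem=2):
    """Arithmetic intensity of an (m x k) @ (k x n) matmul.

    Counts one read of each input and one write of the output,
    i.e. it assumes perfect on-chip reuse.
    """
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def bound(intensity):
    """Classify an op as compute- or memory-bound relative to the ridge point."""
    return "compute" if intensity >= ridge_point() else "memory"


if __name__ == "__main__":
    for batch in (1, 4096):
        ai = matmul_intensity(batch, 8192, 8192)
        print(f"batch={batch}: {ai:.1f} FLOP/byte, {bound(ai)}-bound")
```

Under these assumptions, a batch-size-one matmul does roughly one FLOP per byte of weights loaded and sits far below the ridge point, while the same matmul at a batch of several thousand tokens sits well above it. That is why the two workloads call for different optimizations.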
Overall, these different operations in our model each make use of different resources available on the GPU (e.g. tensor cores, non-matmul compute units, HBM bandwidth, NVLink bandwidth) in unique ways. Therefore, a key area for optimizing this high-throughput workload is to overlap multiple kinds of work in order to simultaneously use more of the GPU's resources. We want to do this across many levels of the GPU -- within an individual SM, across multiple SMs, and even across GPUs.
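As a toy illustration of why overlapping helps, the sketch below compares running ops back-to-back against an idealized bound where ops on different resources overlap perfectly. The op names, resource labels, and timings are all invented for illustration; the bound ignores data dependencies and contention, so it is an optimistic limit and not a model of what the megakernel achieves.

```python
from collections import defaultdict

# Hypothetical ops: (name, dominant resource, time in microseconds).
# All numbers are made up for illustration only.
OPS = [
    ("qkv_matmul", "tensor_cores", 40.0),
    ("attention_decode", "hbm", 25.0),
    ("all_reduce", "nvlink", 30.0),
    ("rms_norm", "hbm", 5.0),
    ("mlp_matmul", "tensor_cores", 60.0),
]


def sequential_time(ops):
    """Total time when every op runs alone, one after another."""
    return sum(t for _, _, t in ops)


def overlapped_lower_bound(ops):
    """Idealized time if ops on different resources overlap perfectly.

    Each resource is busy for the sum of its own ops; the busiest
    resource sets the bound. Dependencies between ops are ignored.
    """
    busy = defaultdict(float)
    for _, resource, t in ops:
        busy[resource] += t
    return max(busy.values())


if __name__ == "__main__":
    print("sequential:", sequential_time(OPS), "us")
    print("overlapped (ideal):", overlapped_lower_bound(OPS), "us")
```

With these made-up numbers, the sequential schedule takes 160 us while the idealized overlapped schedule is limited by the 100 us of tensor-core work. The gap between the two is the headroom that overlapping is trying to recover.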
Existing approaches to overlapping include assigning different SMs to different ops, developing custom kernels to run prefill and decode simultaneously, and running kernels in parallel with cross-GPU memory copying operations. Here, we show that the same simple, interpreter-based megakernel patterns we previously introduced can also achieve all of these fine-grained overlapping patterns -- and more! Most excitingly, despite the significant differences between our low-latency and high-throughput workloads, our core megakernel abstraction (a pipelined instruction interpreter that runs on each SM) is highly transferable across both domains.
In the rest of this blog, we will:
- Give a brief recap on the design of our megakernels from our last, low-latency post.
- Walk through the details of the tensor-parallel Llama forward pass that we map into our megakernel, including a novel approach to communicating intermediate results across GPUs right after running attention. This new operation requires a complicated multi-GPU transpose not efficiently expressible with standard communication patterns, but is trivial to implement within the megakernel!
- Show how megakernels can achieve fine-grained resource overlapping at multiple levels of the GPU hierarchy: within individual SMs, across multiple SMs, and across multiple GPUs!
  - Within individual SMs, the same inter-instruction pipelining we used in low-latency Llama can also help overlap memory movement and compute across instructions, thereby keeping the tensor cores running.