In a previous post on language modeling, I implemented a GPT-style transformer. Lately I’ve been learning mechanistic interpretability to go deeper and understand why the transformer works on a mathematical level.
This post is a brain dump of what I’ve learned so far after reading A Mathematical Framework for Transformer Circuits (herein: “Framework”) and working through the Intro to Mech Interp section on ARENA. My goal is to describe my current intuition for the paper, especially the parts I was confused about, in the hope that my take can help others gain clarity on those areas as well.
First, a brief aside on my overall motivation for working on this stuff. Mechanistic Interpretability (MI/mech interp) is the study of ML model internals that aims to understand, from first principles, why models behave and work as they do. You can think of it as the machine learning analogue of reverse engineering software. It is similar in spirit to the science of biological neural networks, but applied to artificial neural networks instead.
MI is part of the broader field of interpretability, which in turn feeds into AI alignment. Alignment strives to make large AI models aligned with human values. Basically, the overall goal is to understand and control the models before they control us, and to ensure that they don’t engage in harmful, deceptive, dangerous, or subversive behavior. Unfortunately, we live in a world where large language models have encouraged “successful” suicide, engaged in blackmail for self-preservation, and asserted humans should be enslaved by AI. This current version of reality is unacceptable to me.
And as if that weren’t enough, we don’t even understand why these models do what they do. They are arguably the only man-made technology in history that we don’t fully understand from first principles. Given this state of reality, I think that alignment is one of the most important problems we face today and one we have to get right. As a personal bonus, the alignment problem is as fascinating as it is important. It provides an outlet for me to leverage my specific technical skills and interests towards a meaningful cause. It is also extremely difficult, and I like a good challenge.
Ok, now back to the originally scheduled programming.
Attention-Only Transformers
Framework does a deep dive into the key components of a simplified transformer-based language model. It analyzes transformer blocks that contain only multi-head attention: no MLPs and no layernorms. This leaves the token embedding and positional encoding at the beginning, followed by n layers of multi-head attention, followed by the unembedding at the end. Here is a picture of a single-layer transformer with a single attention head:
My goal in this post is not to re-derive all the math: the Framework paper does that better, and Neel Nanda’s walkthrough of the paper on YouTube is also good for it. (That said, this material only really started to click for me after I worked through the “Intro to Mech Interp” problems on ARENA, which I recommend doing if you are actually interested in doing this stuff yourself.)
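Still, the architecture described above is small enough to sketch end to end. Here is a minimal NumPy sketch of a one-layer, one-head attention-only transformer, following the weight-matrix naming from the Framework paper (W_E, W_pos, W_Q, W_K, W_V, W_O, W_U); the function name and the flat (non-batched) shapes are my own simplifications, not code from the paper or ARENA:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_only_transformer(tokens, W_E, W_pos, W_Q, W_K, W_V, W_O, W_U):
    """One-layer, one-head, attention-only transformer: no MLP, no layernorm.

    tokens: list of token ids, length T
    W_E:   (d_vocab, d_model) token embedding
    W_pos: (n_ctx, d_model)   positional embedding
    W_Q, W_K, W_V: (d_model, d_head)
    W_O:   (d_head, d_model)
    W_U:   (d_model, d_vocab) unembedding
    """
    T = len(tokens)
    x = W_E[tokens] + W_pos[:T]                 # residual stream: embed + position
    q, k, v = x @ W_Q, x @ W_K, x @ W_V         # queries, keys, values
    scores = q @ k.T / np.sqrt(W_Q.shape[1])    # scaled dot-product scores
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores[mask] = -np.inf                      # causal mask: no attending ahead
    pattern = softmax(scores, axis=-1)          # attention pattern, rows sum to 1
    z = pattern @ v                             # mix value vectors across positions
    x = x + z @ W_O                             # head output added back to residual
    return x @ W_U                              # unembed -> logits, shape (T, d_vocab)
```

The point of writing it out like this is that every operation between the embedding and the unembedding is linear except the softmax, which is the property the Framework paper exploits when it factors the model into QK and OV circuits.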