A 20-Year-Old Algorithm Can Help Us Understand Transformer Embeddings

Suppose we ask an LLM: “Can you tell me about Java?” Which “Java” is the model thinking about: the programming language or the Indonesian island? To answer this question, we can try to understand what is going on inside the model. Specifically, we want to represent the model’s internal states in a human-interpretable way by identifying the concepts the model is thinking about.

One approach to this problem is to phrase it as a dictionary learning problem, in which we try to decompose complex embeddings into a sum of simple, interpretable concept vectors selected from a learned dictionary. It is not obvious that embeddings can be broken down into such a linear sum of interpretable elements. However, in 2022, Elhage et al. presented supporting evidence for the “superposition hypothesis,” which suggests that a superposition of monosemantic concept vectors is a good model for these complex embeddings. Finding these concept vectors remains an ongoing challenge. When Bricken et al. first proposed dictionary learning for this problem in 2023, they learned the dictionary with a single-layer neural network called a sparse autoencoder (SAE), an approach that has since become widely popular. Dictionary learning itself goes way back (pre-2000s!), but Bricken et al. opted to forgo established algorithms in favor of SAEs for two main reasons:

“First, a sparse autoencoder can readily scale to very large datasets, which we believe is necessary to characterize the features present in a model trained on a large and diverse corpus. […] Secondly, we have a concern that iterative dictionary learning methods might be “too strong”, in the sense of being able to recover features from the activations which the model itself cannot access. Exact compressed sensing is NP-hard, which the neural network certainly isn’t doing.”
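To make the SAE approach concrete, here is a minimal NumPy sketch of a single-layer sparse autoencoder's forward pass and loss. It assumes a generic design (ReLU encoder, linear decoder whose columns play the role of dictionary vectors, L1 sparsity penalty); the function names, the `l1_coeff` default, and the omitted training loop are illustrative assumptions, not Bricken et al.'s exact setup.

```python
import numpy as np


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Encode an embedding x into nonnegative feature activations, then decode it.

    Shapes (hypothetical): x (dim,), W_enc (n_features, dim), b_enc (n_features,),
    W_dec (dim, n_features), b_dec (dim,).
    """
    # ReLU encoder: activations are nonnegative, and ideally most are zero
    f = np.maximum(0.0, W_enc @ (x - b_dec) + b_enc)
    # Linear decoder: x_hat is a weighted sum of the columns of W_dec (the dictionary)
    x_hat = W_dec @ f + b_dec
    return f, x_hat


def sae_loss(x, f, x_hat, l1_coeff=1e-3):
    """Squared reconstruction error plus an L1 penalty that encourages sparse activations."""
    return float(np.sum((x - x_hat) ** 2) + l1_coeff * np.sum(np.abs(f)))
```

Training would minimize this loss over many embeddings with minibatch gradient descent, which is why the approach scales readily to large datasets.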

In our recent paper, we show that with minor modifications, traditional methods can be scaled to sufficiently large datasets with millions of samples and thousands of dimensions, and that their performance matches that of SAEs on a variety of benchmarks. We can also use established theory to gain insight into how well SAEs apply to different problem sizes, e.g., when less data is available.
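For reference, the dictionary learning problem can be written in standard KSVD notation (the notation is an assumption here, not necessarily the paper's). The embeddings are the columns y_i of Y (an m × n matrix), the concept vectors are the unit-norm columns d_j of the dictionary D (m × K), and the sparse codes are the columns x_i of X (K × n), each with at most k nonzero entries:

```latex
\min_{D,\,X}\; \lVert Y - D X \rVert_F^2
\quad \text{subject to} \quad
\lVert x_i \rVert_0 \le k \;\; \text{for all } i,
\qquad
\lVert d_j \rVert_2 = 1 \;\; \text{for all } j
```

Each embedding is thus approximated as y_i ≈ D x_i = Σ_j X_ji d_j, a sum of at most k concept vectors. The ℓ0 constraint is what makes exact sparse recovery hard in general (the NP-hardness mentioned in the quote), so practical methods rely on greedy or relaxed approximations.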

From 30 days to 8 minutes: Reviving KSVD

Instead of using gradient descent to optimize the dictionary elements, as SAEs do, we consider the previously reigning champion of dictionary learning: the 20-year-old KSVD algorithm. KSVD solves the dictionary learning problem by alternating between two subproblems in a loop: (i) updating the concept-vector assignments for each sample and (ii) updating the concept vectors themselves given those assignments. Intuitively, KSVD generalizes k-means clustering, which uses a similar alternating optimization; the key difference is that KSVD assigns each sample to multiple clusters (concept vectors) at once.
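The alternating loop can be sketched in a few lines of NumPy. This is a plain, sequential KSVD, not DB-KSVD: it assumes orthogonal matching pursuit (OMP), a common choice for KSVD's sparse-coding step, for step (i), and a full SVD per atom for step (ii). The function names and defaults are illustrative.

```python
import numpy as np


def omp(D, y, k):
    """Orthogonal matching pursuit: greedily pick at most k atoms (columns of D) to approximate y."""
    support = []
    residual = y.copy()
    coef = np.zeros(0)
    for _ in range(k):
        # Select the atom most correlated with the current residual
        j = int(np.argmax(np.abs(D.T @ residual)))
        if j in support:
            break
        support.append(j)
        # Least-squares refit of y on all atoms selected so far
        coef, *_ = np.linalg.lstsq(D[:, support], y, rcond=None)
        residual = y - D[:, support] @ coef
    x = np.zeros(D.shape[1])
    x[support] = coef
    return x


def ksvd(Y, n_atoms, k, n_iter=10, seed=0):
    """Plain sequential K-SVD on data Y of shape (dim, n_samples)."""
    rng = np.random.default_rng(seed)
    # Initialize the dictionary with random unit-norm atoms
    D = rng.standard_normal((Y.shape[0], n_atoms))
    D /= np.linalg.norm(D, axis=0)
    for _ in range(n_iter):
        # (i) Sparse coding: assign each sample to at most k atoms
        X = np.column_stack([omp(D, Y[:, i], k) for i in range(Y.shape[1])])
        # (ii) Dictionary update: refit each atom in turn, one after another
        for j in range(n_atoms):
            users = np.nonzero(X[j])[0]  # samples that currently use atom j
            if users.size == 0:
                continue
            # Residual of those samples with atom j's contribution added back
            E = Y[:, users] - D @ X[:, users] + np.outer(D[:, j], X[j, users])
            # Best rank-1 approximation of E gives the new atom and its coefficients
            U, s, Vt = np.linalg.svd(E, full_matrices=False)
            D[:, j] = U[:, 0]
            X[j, users] = s[0] * Vt[0]
    return D, X
```

With k = 1 and all coefficients fixed to 1, the loop reduces to k-means. On real LLM embeddings this code would be very slow: the per-sample OMP loop and the per-atom SVD are the sequential costs that a scaled-up implementation has to avoid.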

In a naïve implementation of KSVD, dictionary elements are updated sequentially, and each update requires computing an expensive singular value decomposition (SVD) of the residuals of the samples assigned to that element. Based on extrapolated timing results, a naïve implementation of KSVD would take over 30 days to produce a dictionary sufficient for interpreting LLM embeddings. In our paper, we introduce algorithmic modifications and a highly efficient implementation, which we call double-batch KSVD (DB-KSVD). DB-KSVD provides a 10,000-fold speedup over a naïve KSVD implementation and lets us find interpretable features for LLM embeddings in just 8 minutes.
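Each KSVD atom update only needs the leading singular triplet (top singular value and vectors) of a residual matrix, yet a naïve implementation computes a full SVD to get it. A common shortcut, used for example in approximate K-SVD variants, is to estimate just that triplet with a few power iterations. The NumPy sketch below illustrates the idea; the function name and iteration count are assumptions, and this is not a claim about how DB-KSVD achieves its speedup.

```python
import numpy as np


def leading_singular_triplet(E, n_iter=50, seed=0):
    """Estimate the top singular value and vectors of E by power iteration on E^T E.

    Assumes E is nonzero and has a clear gap between its top two singular values;
    otherwise convergence is slow.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(E.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        # One power-iteration step on E^T E, renormalized to unit length
        v = E.T @ (E @ v)
        v /= np.linalg.norm(v)
    # Recover the left singular vector and the singular value from v
    u = E @ v
    sigma = np.linalg.norm(u)
    return u / sigma, sigma, v
```

Inside a KSVD atom update, the new atom would be `u` and the updated coefficients of the samples using that atom would be `sigma * v`; each power step costs two matrix-vector products instead of a full decomposition.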

We provide an implementation of DB-KSVD in our open-source Julia package, KSVD.jl.

KSVD.jl can easily be called from Python:

```python
import torch
from juliacall import Main as jl  # Python-to-Julia bridge

# Load KSVD.jl (it must be installed in the Julia environment juliacall uses)
jl.seval("using KSVD")

# Placeholder embeddings: a 128 x 5,000 matrix of random values
Y = torch.rand(128, 5_000, dtype=torch.float32)

res = jl.ksvd(Y.numpy(), 256, 3)  # choose dictionary size (256) and sparsity (3)
print(res.D, res.X)  # learned dictionary vectors and sparse codes
```

... continue reading