*January 20, 2025*

# Dictionary Learning

![[gettyimages-1441960949-640x640.jpg]]

***

In dictionary learning, a variant of the sparse autoencoder (SAE) is used to expand and identify features that are semantically meaningful for model behaviour and capability. This is a little post exploring the application of SAEs to interpretability and how they work, more or less expanding the first section of the dictionary learning paper ^[[Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning](https://arxiv.org/abs/2405.12241)] through analogy. *If you like this or have any questions/feedback please reach out!*

## Classic SAE

An SAE is an autoencoder whose inner dimension is *higher* than its input dimension. The SAE takes an input, blows it up to a higher dimension, and learns to project it back down again.

![[oldschool-sae.png]]
> Image from [*A Unified Coded Deep Neural Network Training Strategy based on Generalized PolyDot codes*](https://www.researchgate.net/publication/327084299_A_Unified_Coded_Deep_Neural_Network_Training_Strategy_based_on_Generalized_PolyDot_codes?_tp=eyJjb250ZXh0Ijp7ImZpcnN0UGFnZSI6Il9kaXJlY3QiLCJwYWdlIjoiX2RpcmVjdCJ9fQ)

To train a "classic" SAE, you minimize a reconstruction loss so that the "down" layer reproduces the activations fed into the "up" layer, and you might add a sparsity objective of some kind on the middle layer. This is an old-school approach to *feature extraction*. The emergent result is that activations in the sparse layer begin to represent features we can intuitively understand. For example, in an SAE trained on the MNIST dataset, some feature would probably activate only for the number "7", and some other feature only for slanted handwriting.

The idea is that we can use a similar thing to peek into the structural features of a transformer's activations and name meaningful features. We stick an SAE into some layer of the transformer and use it to "sniff" for signals and get a better handle on what is going on.

The issue with plain sparse autoencoders is that, unlike with MNIST digits, the most obvious features in a transformer's activations are not necessarily the most important ones. In the transformer, we want to identify features that are relevant to *computational* properties of the overall model; features that tell us what it's *actually doing right now*. The challenge is that the "important parts" of an activation, the ones that signal or instigate a complex behaviour, can be extremely subtle (think butterfly effect). If an SAE is only trained to reconstruct local patterns, the features it learns end up missing the forest for the trees, so to speak.

## Finding Better Features

![[Screenshot 2025-01-20 at 2.12.56 PM.png]]
> From the paper

To address this, in *e2e sparse dictionary learning* you put the same SAEs in the same places, but train them on an objective that preserves the *output* of the entire transformer, so that the SAE is encouraged to mesh with the transformer's overall behaviour (green/blue arrows). This way, the gradient information shaping the SAEs is more holistic, and better reflects the "forest" of the model's behaviour rather than just the "trees".

The features you get from the SAE form a kind of "dictionary". If a transformer's layer activations are layered and information-dense like short poems, the SAE's activations are like a book written about the poem, expanding the different layers and themes into a more explicit form.
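To make the setup concrete, here's a minimal PyTorch sketch of an SAE of this shape. The dimensions (`d_model`, `d_dict`), the ReLU encoder, and the `l1_coeff` sparsity weight are illustrative assumptions rather than the paper's exact architecture or hyperparameters; the e2e variant keeps the same module and only changes what the training loss looks at (sketched further down).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    """Minimal SAE: blow activations up into a wider "dictionary" space, then project back down."""

    def __init__(self, d_model: int = 768, d_dict: int = 16384):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_dict)  # the "up" layer
        self.decoder = nn.Linear(d_dict, d_model)  # the "down" layer

    def forward(self, acts: torch.Tensor):
        codes = F.relu(self.encoder(acts))  # sparse dictionary activations
        recon = self.decoder(codes)         # reconstruction of the input activations
        return recon, codes


def local_sae_loss(sae: SparseAutoencoder, acts: torch.Tensor, l1_coeff: float = 1e-3):
    """"Classic" objective: reconstruct the layer's activations, with an L1 nudge toward sparse codes."""
    recon, codes = sae(acts)
    return F.mse_loss(recon, acts) + l1_coeff * codes.abs().mean()


# e.g. acts would be captured from one transformer layer, shape (batch * seq_len, d_model)
acts = torch.randn(32, 768)
loss = local_sae_loss(SparseAutoencoder(), acts)
```

The only structural difference from an old-school bottlenecked autoencoder is that `d_dict` is much larger than `d_model`, which is what gives individual features room to specialize.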
This unpacking of dense activations into explicit features is a little bit similar to superposition in quantum mechanics. In a sentence: superposition is when you can only measure the overall amount of energy of something, but not really where that energy is currently located. You can have enough "handles" on a particle like an electron to infer how much energy it ought to have, but not enough to say what it's actively doing. Some energy is stored in the electron's position relative to the nucleus, some energy is stored in its momentum... but you don't know how it's allocated between the two. Nonetheless, there is structure that allows scientists to infer the underlying possibilities and their likelihoods, based on things we *do* know.

Borrowing the physics analogy, the transformer's states are in some way analogous to an ambiguous measurement. The SAE's job is essentially to unpack the dense, ambiguous features into a more legible form, where we have the surface area to see more explicitly what could be going on broadly. The trick to a *good* SAE is figuring out how to do this in a meaningful way. If an SAE isn't guided toward forest-level patterns, the features it obtains at the tree level won't be informative about broad behaviour. At the same time, there are many forest-level patterns to choose from, and some are more helpful than others.

The goal of an SAE is a flexible and powerful measurement tool that:

1. **You can understand;** you need concepts you can grasp and reason about as a human. This is analogous to how the "position" and "momentum" of an electron are immediately sensible things to us, allowing us to play with ideas and connect meaning, whereas it's not obvious what to make of a "wave function" in a "Hilbert space". We want to measure things that make sense.
2. **You can adjust;** there are many ways to measure something, and the best angle depends on what you're trying to do. If you're solving a chemistry problem, you'll measure things to cause a reaction. If you're engineering a semiconductor, you'll measure things to prevent one. We want control over what we measure.

Ultimately we want the SAE to help us understand transformer states through intuitive, human concepts. "German", "Python code", and "misleading", for example, are helpful and meaningful features. Additionally, we want to be able to "focus" an SAE based on what we are trying to measure with it. If we are using an SAE to analyze English writing style, we don't need to waste any space on a feature that tells us whether the model is speaking German.

## Minimizing Disturbance

Now, there's the additional question of how the model is affected by the SAE. Amusingly, it's a bit like the uncertainty principle from quantum mechanics. By measuring an electron, you alter its state. By adding an SAE to a transformer, you alter the model. By inserting a foreign layer we introduce (at least) two confounding factors. While we want to measure how the network is working, we want to do it in such a way that:

1) **We aren't making the network worse;** if the model gets worse from adding the SAE, it becomes an annoying thing to use, and it also defeats its purpose as a measurement tool. Ideally, it has no effect at all, but that may not always be possible.
2) **We aren't altering how the network functions;** even if the network *doesn't* take a performance hit, we don't want to alter how it processes information. It would, again, largely defeat the point of the SAE if the observed network doesn't work the same way as the un-observed network.
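As a rough sketch of how these two constraints could be expressed as penalties, assuming hypothetical tensors gathered by running the model on the same tokens with and without the SAE spliced in (this is not the paper's code, and the paper's exact formulation may differ):

```python
import torch.nn.functional as F

# Hypothetical inputs, gathered by running the transformer twice on the same tokens:
#   logits_orig, acts_orig  -- final logits and downstream-layer activations, no SAE
#   logits_sae,  acts_sae   -- the same, but with the SAE's reconstruction spliced in

def faithfulness_losses(logits_orig, logits_sae, acts_orig, acts_sae):
    # 1) Don't make the network worse: keep the output distribution close to the
    #    original one (KL divergence between the two next-token distributions).
    kl = F.kl_div(
        F.log_softmax(logits_sae, dim=-1),   # distribution with the SAE inserted
        F.log_softmax(logits_orig, dim=-1),  # original distribution, as log-probs
        reduction="batchmean",
        log_target=True,
    )
    # 2) Don't alter how the network functions: keep each downstream layer's
    #    activations close to what they would have been without the SAE.
    mse = sum(F.mse_loss(a_sae, a_orig) for a_sae, a_orig in zip(acts_sae, acts_orig))
    return kl, mse
```

A sparsity penalty on the dictionary codes (as in the classic sketch above) would still sit on top of these; the two terms here are only about not disturbing the model being observed.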
In a nutshell, the two loss terms used to train the SAE are chosen to control these effects. The KL-divergence between the SAE'd network and the un-SAE'd network addresses the first point: if the outputs of the network aren't changing much, then the model is probably still capable of all the same stuff as before, and this should be fairly well correlated with overall performance on whatever you're looking at. The MSE loss between the affected and unaffected versions of each downstream layer addresses the second point: if the activations after the SAE look about the same with or without it, it stands to reason that the network is doing a similar thing in either case.

## Quantifying Precision

Now we have a sense of what we want the SAE to be good at doing, and a sense of what we don't want it to do. What can we quantify?

First, we can quantify the overall performance hit of the SAE; cross-entropy loss over an evaluation set is a reasonable way to do that. The heralded *scaling laws* say, in a nutshell, that most capabilities tend to improve roughly in proportion to improvements in cross-entropy loss. So the increase in cross-entropy evaluation loss when we add the SAE is probably an okay measure of what we're taking out of the model, and it's a number we want to keep as low as possible (vertical axis in the upcoming plot).

Now, a more challenging question: how can we quantify the *precision* of an SAE? To quantify the precision of a microscope, there are two main concerns: 1) How far can you zoom in with it? 2) How is the image quality affected when you zoom in? By analogy, for the SAE we can quantify: 1) How small can we make the dictionary? 2) How much do the dictionary features overlap?

![[Pasted image 20250120125837.png]]
> Figure 1 from the dictionary learning paper

The first point is quantified by the number of active dictionary elements (right plot). After training, a number of features in the SAE won't be used at all; the features that *are* used are said to be "alive". If you can get an SAE that needs only a small number of dictionary elements to explain all of the model's behaviour, it's good at zeroing in on features. For example, if the dictionary needed just *two* features, that would be pretty remarkable: it would imply we could describe what the model does as a ratio of two features like "good" and "bad". This would be efficient, but maybe not precise or helpful. Realistically, shrinking the dictionary this much would hamper the model's performance and/or result in blurry lines.

Now, if we have a more reasonable number of dictionary elements, say 10,000-20,000, we still want them to overlap as little as possible, so that they are easy to disambiguate. The $L_0$ norm (left plot) is the number of nonzero entries in a given dictionary encoding, which makes it a good measure of this: a low $L_0$ means the SAE's features are generally less likely to overlap.

In a nutshell:

- Fewer dictionary entries overall $\rightarrow$ more efficient features
- Sparser individual encodings $\rightarrow$ less ambiguous features

So the overall goal is to strike a balance between these two things while keeping the performance hit (the increase in eval cross-entropy loss) as small as possible.
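Both precision measures are easy to read off from the dictionary codes themselves. A rough sketch (shapes and sizes are illustrative; in practice "alive" would be judged over a large evaluation set, not a single batch):

```python
import torch

def dictionary_metrics(codes: torch.Tensor):
    """codes: dictionary activations for a batch of tokens, shape (n_tokens, d_dict)."""
    active = codes > 0
    l0 = active.sum(dim=-1).float().mean()  # average nonzero entries per token (sparsity)
    alive = active.any(dim=0).sum()         # dictionary elements that fire at least once
    return l0.item(), alive.item()

# Illustrative example: 8 tokens against a 16,384-feature dictionary
codes = torch.relu(torch.randn(8, 16384) - 2.0)  # mostly zeros, a few positive codes
l0, alive = dictionary_metrics(codes)
print(f"L0 ≈ {l0:.1f} features per token, {alive} alive dictionary elements")
```

These are exactly the two horizontal axes in the plot above ($L_0$ on the left, alive dictionary elements on the right), with the performance hit on the vertical axis.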