A perspective on continual learning, motivating our paper on sparse memory finetuning
If we want to move towards a world where models are “always training” and continually learning from experience over time, we need to address a basic challenge: how do we keep updating the parameters of a model without breaking it? In this post, I’ll motivate memory layers as a natural architecture for this paradigm: high-capacity, but sparse (few active parameters) on each forward pass. In our recent paper, we found that finetuning memory layers enables learning with far less forgetting than either LoRA or full finetuning. When learning TriviaQA facts, NaturalQuestions performance drops by 89% with full finetuning and 71% with LoRA, but only 11% with memory layers. Along the way, I’ll also discuss the challenges of the continual learning problem more broadly.
Check out the paper here: Continual Learning via Sparse Memory Finetuning
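To make the architectural idea concrete, here is a minimal sketch of a sparse memory layer in PyTorch. The class name and hyperparameters are my own for illustration, and this is a simplified version: actual memory layers use tricks like product-key lookup so they never score every slot. The point is just that a very large pool of key/value parameters provides capacity, while each token only touches (and only sends gradients into) its top-k slots.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMemoryLayer(nn.Module):
    """Toy sparse memory layer: a large key/value table, top-k reads per token.

    Capacity scales with num_slots, but each token only reads from (and
    backpropagates into) its top-k slots, so few parameters are active on
    any given forward pass.
    """

    def __init__(self, d_model: int, num_slots: int = 65536, topk: int = 32):
        super().__init__()
        self.keys = nn.Parameter(torch.randn(num_slots, d_model) * 0.02)
        self.values = nn.Embedding(num_slots, d_model)
        self.topk = topk

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model)
        scores = x @ self.keys.T                          # (batch, seq, num_slots)
        top_scores, top_idx = scores.topk(self.topk, -1)  # (batch, seq, topk)
        weights = F.softmax(top_scores, dim=-1)
        selected = self.values(top_idx)                   # (batch, seq, topk, d_model)
        # Weighted sum over the k selected value vectors
        return (weights.unsqueeze(-1) * selected).sum(dim=-2)
```

In the paper, finetuning then updates only a small, input-dependent subset of these slots rather than the whole table, which is where the “sparse memory finetuning” name comes from.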
Generalization vs. Integration
Continual learning has been studied for decades at this point. I think a lot of the discussion today is muddied by the fact that there are many classic formalizations of this idea, potentially none of which align with what we think of as real-world continual learning.
Intuitively, what many people think of is a system that can be taught like an intern. Every time it encounters a new experience, learns a new fact, or gets feedback from the user, the system should get smarter over time, just like people do.
What are the research questions we need to solve before this is a reality?
I think of continual learning as two subproblems:
Generalization: given a piece of data (user feedback, an experience, etc.), what update should we make to learn the “important bits” from that data?
Learning to regurgitate information is easy (just overfit / memorize). The challenge is learning something that can be used in diverse future settings. This is well understood now for fact learning. When we see a string like “Barack Obama was born in Hawaii” in the training corpus, what I want the model to learn is not that the token “Hawaii” always follows the prefix “Barack Obama was born in,” but the semantic content that the tokens represent: “Barack Obama,” a real-world entity, was “born” (abstractly, “came into existence”) in “Hawaii,” a real-world entity. From the perspective of a language model consuming the raw tokens, it’s ambiguous which “hypothesis” it’s supposed to learn from the data. This explains why paraphrasing is necessary to robustly teach models new facts (Physics of Language Models). Augmentation disambiguates which hypothesis we want the model to learn, as sketched below.
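As a tiny illustration (the specific paraphrases and data format here are my own, not from the paper), augmenting a single fact into several surface forms before finetuning rules out the degenerate string-level hypothesis:

```python
# Hypothetical example: hand-written paraphrases of one fact. In practice
# these are typically generated with templates or another LLM.
fact = "Barack Obama was born in Hawaii."
paraphrases = [
    "Barack Obama was born in Hawaii.",
    "Hawaii is the birthplace of Barack Obama.",
    "Q: Where was Barack Obama born? A: Hawaii.",
    "Barack Obama, the 44th U.S. president, was born in the state of Hawaii.",
]

# Training on many surface forms of the same fact rules out the shallow
# hypothesis ("the token 'Hawaii' follows this exact prefix") and pushes the
# model toward the semantic one (this entity was born in this place).
train_examples = [{"text": t} for t in paraphrases]
```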