The magic of AI is generalization – models go beyond what is literally in their training data and manage to handle “similar” cases. Statistical learning theory has traditionally been the lens through which AI researchers try to understand generalization mathematically.
Statistical learning theory models generalization as follows. There is a data distribution from which all data – training, validation, test – are drawn independently from the same distribution (i.i.d.). The goal of the learner is then to learn a good approximation to this underlying distribution. For example, when training a digit classifier, we have a training sample of images drawn from the data distribution, and we want to build a classifier that works well on the whole distribution, not just on the training data.
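To make this concrete, here is the standard setup in a minimal sketch (the notation is generic, not taken from any particular reference):

```latex
% Distribution D over examples (x, y), hypothesis class H, loss function \ell.
\[
R(h) \;=\; \mathbb{E}_{(x,y)\sim D}\big[\ell(h(x), y)\big]
\qquad \text{(true risk)}
\]
\[
\hat{R}_S(h) \;=\; \frac{1}{n}\sum_{i=1}^{n} \ell(h(x_i), y_i)
\qquad \text{(empirical risk on an i.i.d. sample } S \text{ of size } n)
\]
% The learner returns \hat{h} \in \arg\min_{h \in H} \hat{R}_S(h), and
% generalization bounds control the gap R(\hat{h}) - \hat{R}_S(\hat{h})
% in terms of n and the complexity of H.
```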
Since Valiant (1984), there has been a large body of very beautiful mathematical work on when this works, and under what conditions on the data distribution and the class of classifiers. My goal here is not to go into the details of this work, but to draw out some high-level insights from the body of work as a whole. In this post I will describe what learning theory gets right about ChatGPT, and in the next post I will talk about where the gaps are.
What Statistical Learning Theory Gets Right
There are of course two obvious things that statistical learning theory gets right – more data is better, and inductive bias matters. Both are definitely borne out by observation – scaling laws establish the former, and inductive bias in the form of the right transformer architecture demonstrates the latter. But here I am talking about slightly less obvious lessons.
One of the biggest lessons of statistical learning theory for generative models is this: good generalization means models will reflect the statistical patterns of the training data distribution. Of course, we do not know exactly what the distribution of internet text is, but squinting from a very long distance, we sort of know what it might look like. A key prediction from learning theory is that trained models should reproduce the frequencies and patterns they observe during training – and this turns out to be strikingly true in subtle ways.
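One way to see why this prediction falls out of the theory (a rough sketch that ignores model capacity and optimization issues): generative models are typically trained to minimize expected log loss, and that objective is minimized exactly when the model reproduces the data distribution.

```latex
% Expected log loss (cross-entropy) of a model p_\theta against the data distribution D:
\[
\mathbb{E}_{x \sim D}\big[-\log p_\theta(x)\big]
\;=\; H(D) \;+\; \mathrm{KL}\!\left(D \,\|\, p_\theta\right)
\]
% Since KL(D || p_\theta) >= 0, with equality iff p_\theta = D, the objective is
% minimized when the model matches the training distribution -- including the
% frequencies of whatever properties that distribution happens to encode.
```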
Here is an example: if you ask a language model to generate a random number, the answer is most frequently 7 – the same as when you ask a human being. One explanation is that there is something biologically special about 7, and the language model somehow magically learns it, but a much simpler explanation comes from learning theory – humans most frequently report 7 as a random number in their writing, and that writing forms the training data for most large language models.
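If you want to check this yourself, the experiment is easy to sketch. Below, `ask_model` is a hypothetical stand-in for whatever chat API you have access to; the rest simply samples the prompt repeatedly and tallies the first integer in each reply.

```python
import re
from collections import Counter

def ask_model(prompt: str) -> str:
    """Hypothetical stand-in for a call to your favorite LLM API."""
    raise NotImplementedError("plug in your model call here")

def random_number_histogram(n_samples: int = 200) -> Counter:
    """Ask the model for a random number many times and tally the answers."""
    prompt = "Pick a random number between 1 and 10. Answer with just the number."
    counts = Counter()
    for _ in range(n_samples):
        reply = ask_model(prompt)
        match = re.search(r"\d+", reply)  # grab the first integer in the reply
        if match:
            counts[int(match.group())] += 1
    return counts

# The learning-theory prediction: the histogram is skewed toward 7,
# mirroring how humans report "random" numbers in web text.
```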
This pattern – of reproducing the frequencies seen in training – also shows up in the fine-tuning setting. There are many examples in the literature, but one that we have seen in our own work is in our recent NeurIPS paper. There, we fine-tune a large language model on the ChatDoctor dataset of doctor-patient chat conversations. When we then put the model into generative mode, we find that it generates conversations whose properties occur with the same frequencies as in the fine-tuning data. For example, if 30% of the training conversations involve women patients, then close to 30% of the generated conversations will feature women patients as well.
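A crude way to test this kind of frequency matching is sketched below, using a keyword-based proxy for “the conversation involves a woman patient”; a real analysis would use a proper annotator, and the data loading is assumed rather than shown.

```python
def mentions_woman_patient(conversation: str) -> bool:
    """Crude keyword proxy for 'this conversation involves a woman patient'."""
    keywords = ("she ", " her ", "female patient", "woman")
    text = conversation.lower()
    return any(k in text for k in keywords)

def property_frequency(conversations: list[str]) -> float:
    """Fraction of conversations in which the property holds."""
    if not conversations:
        return 0.0
    return sum(mentions_woman_patient(c) for c in conversations) / len(conversations)

# train_convs: fine-tuning conversations; gen_convs: samples from the tuned model.
# The learning-theory prediction is that the two numbers come out close,
# e.g. ~0.30 in both if 30% of the fine-tuning data involves women patients.
train_convs = ["Patient: I am a 45-year-old woman with chest pain...", "..."]
gen_convs = ["Patient: She has had a fever for three days...", "..."]
print(property_frequency(train_convs), property_frequency(gen_convs))
```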
A third example is in text-to-image generation models, which are usually trained on very large-scale annotated image datasets from the web. A well-known problem with these models is that they do not understand negation – ask an otherwise high-quality model to generate a cat but not a dog, and it will generate both a cat and a dog. This is again statistical learning theory in action – web data usually annotates images with what is in the image, not with what is absent. Training a model on this kind of data naturally does not teach it about negation.
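The data-side claim is easy to eyeball: scan a caption dataset for negation cues and you will find very few. A minimal sketch, where the captions are illustrative placeholders and the cue list is deliberately crude:

```python
import re

NEGATION_CUES = re.compile(r"\b(no|not|without|never)\b", re.IGNORECASE)

def negation_rate(captions: list[str]) -> float:
    """Fraction of captions containing an explicit negation cue."""
    if not captions:
        return 0.0
    return sum(bool(NEGATION_CUES.search(c)) for c in captions) / len(captions)

# Web-style captions describe what is present, almost never what is absent,
# so this rate tends to be tiny -- which is why "a cat but not a dog" is
# effectively out of distribution for the trained model.
captions = [
    "a cat sitting on a windowsill",
    "a dog playing fetch in the park",
    "two people riding bicycles at sunset",
]
print(negation_rate(captions))
```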
Many more examples like these can be found even in highly sophisticated models – an indication that statistical learning theory can still give us interesting insights. In the next post I will talk about some examples where statistical learning theory falls short. Stay tuned!