We’ve now covered the core components of transformers in Parts I, II, and III. Here, I want to step back a bit and think about what a trained model is really doing, how general the architecture is, and how it scales. This post will collect a number of different points together, so there’s less of a narrative than usual.
Thinking about What the Trained Algorithm Does
We understand the architecture of LLMs, but we don’t really know what algorithms they learn. Although there’s been some noble effort lately to get more traction on what’s going on internally (which I’ll talk about in Part V), to a first approximation, we just have a bunch of gigantic inscrutable matrices being multiplied and added together, and we end up with sensible outputs.
But as the matrices get altered with training, very different things go on under the hood. What the various attention heads pay attention to will change drastically with training. The information moved with attention, the way information gets encoded, the computation done by MLPs, and the way blocks from various layers interact and compose will also change. Furthermore, even if we get a bit of a handle on how individual parts of a transformer model work, we still would be a long way from understanding the model as a whole. Developing a cognitive science of transformers is no easy task.
Internal Representations
LLMs can ultimately do a very good job at conversing with humans. But internally, we don’t have much of an idea about how they represent information. They can answer at least some questions about the physical world in a way that suggests there’s some kind of internal representation of causation and physics, but we don’t have any real idea of how this works.
It’s also unclear whether the model develops some kind of representation of English grammar. No information about grammar is fed to the model at the outset. (Perhaps a little is encoded in the tokenization, but that seems pretty minimal.) There’s no specific module for learning syntactic structures or building parse trees or anything like that. Nonetheless, we end up with fluent and grammatical text.
There is just very little direct symbolic manipulation, and I think one important line of future research will be reverse engineering the learned algorithms to see whether anything like, say, a parse tree is represented anywhere within the model.
A Conceptual Mistake
It’s tempting to think that “all” large language models are trying to do is predict the next token in a sequence and minimize predictive error. Even if we ignore the potential problems with anthropomorphizing the model and allow that it is actually trying to do something, there are two ways in which this idea is still wrong.1
First, as we just discussed in the last post, ChatGPT is fine-tuned based on how much humans like its answers (and also via an automated model that predicts how much humans will like its answers). This is after it’s pre-trained in the standard way to get it to minimize predictive loss. So, it’s really quite unclear at this point what it could be optimizing for, since it’s gotten two very different types of feedback.
Second, even if we stopped after the pre-training and even if we allow that the model is genuinely trying to do something, it’s not at all obvious that it’s trying to minimize loss. All we know is that, as a matter of fact and for predictable reasons, the training process will find a model that uses an algorithm that gets low loss on data relatively similar to the data it’s been trained on. It might not “care” about loss at all.
Consider an analogy. Evolution by natural selection is a lot like a training algorithm. You know that at the end of a long process, in environments like the environment where organisms evolved, we’ll end up with organisms that are pretty good at maximizing inclusive genetic fitness. But when my dog scratches her ear or lies in the grass, it is generally inappropriate to describe her as “trying to maximize her inclusive genetic fitness.” That’s not to say that there isn’t some context in which we can describe her as trying to do that, but it’s not really the way you should be thinking about what she’s up to. Likewise, when humans use birth control, prove theorems, collect baseball cards, fly to outer space, or tweet, you generally shouldn’t be thinking of them as maximizing IGF.2
General Sequence Predictors
I’ve spoken of transformers as language models. But the transformer architecture can predict sequences quite generally. You just need some kind of vocabulary, which you can embed in a way that encodes positional information, and the architecture will spit out predictions about what you’ll see next. So, transformers could be used to predict images, stock prices, mathematical sequences of numbers, the results of physical experiments, physical trajectories of particles, or any other sequence. Transformers are also Turing complete.
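To make this concrete, here is a minimal sketch of a decoder-style transformer used as a generic next-symbol predictor. It assumes PyTorch, and the class name and sizes are mine, invented for illustration rather than taken from any real model. Nothing in it is specific to language: the integer IDs could just as well index image patches, price buckets, or discretized measurements.

```python
# A minimal sketch, assuming PyTorch; names and sizes are illustrative only.
import torch
import torch.nn as nn

class TinySequencePredictor(nn.Module):
    def __init__(self, vocab_size=1000, d_model=64, n_heads=4, n_layers=2, max_len=256):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)   # symbol IDs -> vectors
        self.pos_emb = nn.Embedding(max_len, d_model)        # learned positional encoding
        block = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model, batch_first=True)
        self.blocks = nn.TransformerEncoder(block, n_layers)
        self.head = nn.Linear(d_model, vocab_size)           # scores over the next symbol

    def forward(self, ids):                                   # ids: (batch, seq_len)
        seq_len = ids.size(1)
        pos = torch.arange(seq_len, device=ids.device)
        x = self.token_emb(ids) + self.pos_emb(pos)
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len)  # causal: no peeking ahead
        x = self.blocks(x, mask=mask)
        return self.head(x)                                   # (batch, seq_len, vocab_size)

# The input here is just random symbol IDs; the IDs could stand for anything.
logits = TinySequencePredictor()(torch.randint(0, 1000, (1, 16)))
```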
Of course, we don’t yet have models that are good at all the tasks listed, and the fact that the architecture is Turing complete doesn’t mean we have any idea how to train transformers to execute any computational task we want them to. It just means that, for an arbitrarily large model, some set of parameters exists that can simulate any program. Nonetheless, the right transformer model could be a much more general intelligence than the chatbots we have today.
Scaling and Capabilities
One thing you may have noticed about the transformer architecture is that it’s easy to scale. You can make the embeddings bigger, the heads bigger, add more layers, add more heads, and so on.
Empirically, it turns out that the big drivers of performance are the number of parameters, the amount of computation available for training, and the size of the dataset used for training.
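As a rough illustration of how those knobs feed into parameter count, here’s a back-of-the-envelope sketch. The 12·d_model² per layer figure is a common approximation (roughly 4·d² for the attention projections plus 8·d² for an MLP four times the model width); embeddings aside, biases and layer norms are ignored, so treat the outputs as ballpark figures only.

```python
# A back-of-the-envelope sketch; the numbers are rough approximations, not exact counts.
def approx_params(n_layers, d_model, vocab_size):
    embedding = vocab_size * d_model      # token embedding (often shared with the output head)
    per_layer = 12 * d_model ** 2         # ~4*d^2 attention projections + ~8*d^2 MLP weights
    return embedding + n_layers * per_layer

# Doubling the width roughly quadruples the per-layer count; doubling the depth doubles it.
print(f"{approx_params(n_layers=12, d_model=768, vocab_size=50257):,}")   # ~124M, GPT-2-small-ish
print(f"{approx_params(n_layers=48, d_model=1600, vocab_size=50257):,}")  # ~1.5B, GPT-2-XL-ish
```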
There are two points I want to make here. First, while the decrease in loss is predictable as you scale up, the gain in capabilities is not. That is, what we can predict is that, with scale, somehow or other models will get better and better at predicting the next token.
What we can’t predict is what new skills they will gain along the way. For instance, at some point, LLMs start being able to do three-digit arithmetic. They weren’t specifically trained to do this; it just fell out as an accident of training. Being able to do arithmetic does help lower the loss a bit, but there was no way to know ex ante at what size the model would gain this ability. Furthermore, many of these skills are acquired suddenly, rather than emerging gradually over the course of training.3
Second, if you hang out on social media at all, you sometimes see (non-ML) professors worried about the future of teaching say things like, “Well, ChatGPT can write semi-coherent prose, but it can’t do [x],” where [x] is something like writing creatively, solving mildly complex math problems, generating interesting new ideas, passing some college-level exam, or whatever.4 These professors are right: ChatGPT does not have those capabilities. But it’s not at all clear that we need any kind of conceptual innovation to get the next chatbot to that level. We may just add some more layers and parameters, do some more training, throw a bit more compute at our model next time, and voilà.
Some Problems for Scaling
There are some problems with scaling LLMs indefinitely. For one, we do need more and more training data, and at some point you run out of internet to read. This could eventually be a problem.
Secondly, you may have noticed that the attention heads are rather inefficient. Every token potentially pays attention to itself and every previous token, which means the amount of computation needed scales roughly with the square of the input length. This is why even very large LLMs tend to train on inputs of only 1024 or 2048 tokens. (And, recall, that’s not 2048 words; tokens for big LLMs are sub-word length.) What this means is that transformer models can’t look especially far back in a passage to make a prediction about what comes next, and scaling to handle larger and larger input lengths will be expensive.
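To see the quadratic blow-up concretely, here’s a tiny NumPy illustration (the values are arbitrary; only the shapes matter): the matrix of attention scores alone has seq_len × seq_len entries per head.

```python
# A small illustration of why attention cost grows with the square of the context length.
import numpy as np

d_head = 64
for seq_len in (1024, 2048, 4096):
    Q = np.random.randn(seq_len, d_head)   # one query vector per position
    K = np.random.randn(seq_len, d_head)   # one key vector per position
    scores = Q @ K.T                        # every position scored against every position
    print(seq_len, scores.shape, f"{scores.size:,} score entries")
# Doubling the context length quadruples the number of scores (and the compute to produce them).
```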
My guess is that there will be tricks for getting around the relatively low token limit with transformers—e.g., there might be a way to embed information about earlier parts of a bigger context in a succinct vector that encapsulates a bunch of text. But I’m not an ML engineer, and that’s all speculative. So, scaling may well be bottlenecked in the relatively near future, but I wouldn’t bet on it.
Further Resources
Paper on scaling laws
Paper on inner alignment
Paper on Turing completeness
This paper by Murray Shanahan makes a sophisticated version of this error, where Shanahan argues that LLMs don’t have concepts of truth and falsity, but are instead just finding statistically likely continuations of sequences.
This is a point that’s more powerful in the context of reinforcement learning models, where the model is rewarded based on some goal, such as the number of points it gets in an Atari game or the amount of money it earns in the stock market. I believe the classic source is this paper.
One aside here. Because of the way tokenization works, strings of numbers won’t always be tokenized with the same number of digits. E.g., 256 may be represented as a single token, but 197 may be represented as the token 19 followed by the token 7. So three-digit arithmetic is a bit harder than it appears for LLMs!
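For instance, here’s a quick way to inspect this, assuming the open-source tiktoken package is installed; exactly which numbers get split depends entirely on the tokenizer’s learned vocabulary, so the splits you see are illustrative rather than guaranteed.

```python
# Check how a GPT-2-style BPE tokenizer splits number strings (tiktoken assumed installed).
import tiktoken

enc = tiktoken.get_encoding("gpt2")
for s in ["256", "197", "999", "1234"]:
    ids = enc.encode(s)
    print(s, "->", [enc.decode([i]) for i in ids])   # the pieces each number is split into
```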
I felt bad hunting around twitter for specific threads.