In the last two posts, I gave a high-level walkthrough of the transformer architecture. In this post and the next, I’ll talk about the training process—i.e., how you start with a particular architecture and then improve its performance over time.
If you’re already familiar with machine learning paradigms and architectures other than transformers, I expect there won’t be many surprises about the mechanics of training—it’s just gradient descent and backprop as normal. So, you can probably skip or skim most of this post, though you may still want to check out the section on fine-tuning ChatGPT at the end.
I think there are also a lot of conceptual mistakes people make about what the result of training is and what happens when transformer models are scaled up, so I will cover that in the next post.
If you’re unfamiliar with machine learning, it’s good to get a sense of how training works. But since the training for transformers is relatively generic, I will keep this post on mechanics very high-level and conceptual and almost entirely math-free, though I’ll provide some resources for folks who want to learn more.
Training
If you’ve played around with ChatGPT, you know that it produces fluent English sentences and often gives interesting and coherent responses. But all I said in Parts I and II is that transformers have embeddings, attention heads, multi-layer perceptrons, and unembeddings that ultimately come together to make predictions about the next token in a sequence. I didn’t say anything in depth about how they figure out how to make good predictions about the next token.
To produce reasonable text, all the components in the model have to coordinate with one another and do reasonable things themselves. Embeddings have to figure out what sort of information to encode and how to do it, attention heads have to figure out what to pay attention to based on the embeddings, and so on.
How does that work? Very briefly: with a bunch of totally wild guesses that all get refined over time.
In more depth: we first establish the architecture. We determine:
how many dimensions the embedding should have
how many attention heads there should be in each layer
how many dimensions the attention heads should have
how many dimensions the MLPs should have
what activation function the MLPs should use (e.g., the ReLU), and
how many layers there should be in total.
As a reminder, in GPT-3 Small, the embeddings are 768 dimensional, there are 12 heads per layer, each head is 64 dimensional, and there are 12 layers total. That is all pre-set and does not change over the course of training.
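To make those choices concrete, here is a minimal sketch of GPT-3 Small's pre-set hyperparameters as a plain Python dictionary. The MLP width (four times the embedding size) and the activation are my assumptions about the typical GPT-style setup rather than numbers stated above.

```python
# Fixed architectural choices for GPT-3 Small; these are decided up front
# and never change over the course of training.
gpt3_small_config = {
    "d_model": 768,        # dimensions of each embedding
    "n_heads": 12,         # attention heads per layer
    "d_head": 64,          # dimensions of each attention head
    "d_mlp": 3072,         # MLP hidden width (4x the embedding size, a common choice)
    "activation": "gelu",  # activation used by the MLPs (GPT models typically use GELU)
    "n_layers": 12,        # total number of layers
}
```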
We also pre-set how to tokenize text and encode it in one-hot vectors. In the last post, I treated each word as a token and will generally continue to do so. But really, you don’t want your vocab to be full words, but instead you’ll use subwords, punctuation, special end-of-sentence tokens, START tokens, new paragraph tokens, and so on. In any case, the tokenization is a separate process that the model doesn’t learn. (We’ll also end up encoding positional information, which I’ll totally ignore.)
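To see what "pre-set" means here, consider this toy illustration with a tiny made-up vocabulary (nothing like a real subword tokenizer): the mapping from tokens to one-hot vectors is fixed in advance and is not something the model learns.

```python
# Toy illustration of a fixed vocabulary and one-hot encoding.
vocab = ["<START>", "<END>", "Believe", "truth", "!", "Shun", "error"]
token_to_id = {tok: i for i, tok in enumerate(vocab)}

def one_hot(token):
    """All zeros except a single 1 at the token's index in the vocabulary."""
    vec = [0] * len(vocab)
    vec[token_to_id[token]] = 1
    return vec

print(one_hot("truth"))  # [0, 0, 0, 1, 0, 0, 0]
```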
From there, it’s mostly random guesses.1 For each token in the vocab, we first randomly guess which numbers we should use to encode it. If we use GPT-3 Small’s architecture with 50,000 tokens in the vocab, that’s 768×50,000=38.4 million random guesses right there. The unembedding matrix is the same size, so we get around 76.8 million random guesses devoted to how to encode and decode the vocabulary.
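Here is a minimal sketch of what those initial guesses look like in code: every entry of the embedding matrix simply starts out as a small random number. The particular distribution and scale used here are my assumptions; as the first footnote notes, the choice of initialization distribution is its own small art.

```python
import numpy as np

# Sketch of random initialization for the embedding matrix: 50,000 tokens,
# 768 numbers each, all drawn from a small random distribution to start.
vocab_size, d_model = 50_000, 768
rng = np.random.default_rng(0)
embedding_matrix = rng.normal(loc=0.0, scale=0.02, size=(vocab_size, d_model))
print(embedding_matrix.shape)  # (50000, 768) -- 38.4 million initial guesses
```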
Each attention head comes with its own query, key, value, and output matrices. The Q, K, and V matrices are all 768 by 64 dimensional, and the output matrix is 64 by 768 dimensional, so each head comes with its own set of around 197,000 random guesses. Since there are 12 layers total and 12 attention heads per layer, that’s around 28 million random guesses devoted to attention heads. The MLPs have around 4.7 million random guesses per layer, and around 56.6 million guesses total. Each of these random guesses is a parameter.
If you add those numbers up, you’ll find around 161 million parameters. GPT-3 Small is actually listed as having 125 million parameters because the embedding parameters (all 38.4 million of them) are excluded from the count. By comparison, in full scale GPT-3 there are 175 billion parameters.2
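If you want to check that arithmetic yourself, here is a quick back-of-the-envelope tally using the numbers above (bias terms and positional embeddings are ignored, which is part of why the totals are only approximate).

```python
# Rough parameter count for GPT-3 Small, ignoring biases and positional embeddings.
d_model, d_head, n_heads, n_layers, vocab_size = 768, 64, 12, 12, 50_000

embedding   = d_model * vocab_size                     # ~38.4 million
unembedding = d_model * vocab_size                     # ~38.4 million
per_head    = 3 * d_model * d_head + d_head * d_model  # Q, K, V, and output: ~197,000
attention   = per_head * n_heads * n_layers            # ~28 million
mlp         = 2 * d_model * (4 * d_model) * n_layers   # ~56.6 million

total = embedding + unembedding + attention + mlp
print(f"~{total / 1e6:.1f} million parameters in total")
print(f"~{(total - embedding) / 1e6:.1f} million excluding the embedding matrix")
```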
When you set up the architecture initially, you’ll end up with complete gobbledygook as output if you try to use it. The output will not only make no sense, but it won’t be anywhere close to grammatical. In fact, if you use subwords for your tokens, you won’t even get English words as output. In effect, before any training, the model is just randomly choosing elements from the vocabulary for its next-token predictions.3
So, given that we start with over a hundred million random guesses, it’s going to take a bit of tweaking to get good output. Luckily, there is a lot of text available that the model can learn from.
At a conceptual level, the way this works is that the model is fed a chunk of text like "Believe truth! Shun" and, as always, will assign a probability distribution to all possible next tokens. At the start it might give a high probability to completions such as "baseball", "happy", or "climbs". But it gets scored based just on how high of a probability it assigns to the true next token, which in this case is "error."4

It then will slightly tweak its parameters so that next time, it would be just a little more likely to predict "error" after "Believe truth! Shun". Now, given the massive number of parameters, it’s going to have to do quite a bit of reading. But after going through a large portion of the internet a few times and making minor adjustment after minor adjustment, it will get pretty good at predicting what comes next after a string of text.5
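For readers who like to see things in code, here is a toy sketch of the scoring step for a single position (not real model code): the penalty, or loss, is just the negative log of the probability the model put on the token that actually came next.

```python
import math

# Toy example of scoring a single next-token prediction. The model outputs a
# probability distribution over possible next tokens; it is penalized more the
# less probability it assigned to the true continuation.
predicted_probs = {"baseball": 0.20, "happy": 0.15, "climbs": 0.10, "error": 0.001}
true_next_token = "error"

loss = -math.log(predicted_probs[true_next_token])  # cross-entropy loss at this position
print(f"loss = {loss:.2f}")  # large here, because the model thought "error" was very unlikely
# Training then nudges the parameters so that, after this particular context, the
# probability on "error" goes up slightly and the loss on this example goes down.
```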
You might wonder how we can determine which way the parameters should be tweaked to raise the probability of the actual next token. If you have an entry in, say, the 109th row and 31st column in the query matrix of an attention head in the fifth layer, how do you know whether to raise or lower that entry by a tiny amount to reduce loss? The algorithm used to revise the parameters during training is some flavor of stochastic gradient descent (combined with backpropagation). That is the completely generic method used across all of machine learning these days. Although there are a few different varieties and small modifications, SGD with backprop is the only game in town. I'll avoid going into detail here about how it works because (1) it's completely ubiquitous and many readers will have seen it used elsewhere, (2) a full explanation would require a long, mathy passage, and (3) there are already a lot of different explainers online for those who wish to know more, some of which I'll link to.6 But it really just appeals to basic calculus that tells you which way to tweak each parameter to get a slightly better result next time on the same training example.
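If it helps, here is what that calculus looks like in the simplest possible case: one parameter, one loss function, and repeated small nudges in the direction that lowers the loss. Real training does the same thing, just with hundreds of millions of parameters at once and with the derivatives computed by backpropagation.

```python
# Bare-bones gradient descent on a toy problem: find the w that minimizes (w - 3)^2.
w = 0.0                 # the initial random guess
learning_rate = 0.1

for step in range(50):
    gradient = 2 * (w - 3)         # derivative of the loss (w - 3)^2 with respect to w
    w -= learning_rate * gradient  # nudge w a little in the direction that lowers the loss

print(round(w, 4))  # very close to 3, the value that minimizes the loss
```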
An Important Addition for ChatGPT
After this stage of training, you end up with a model that is very good at predicting the next token in the sequence. But predicting the next token isn’t exactly the same as being a good chatbot. To see why, here’s an example from Scott Alexander, using GPT-3 (the non-chat version):
In this passage, GPT-3 is doing a good job predicting the next token, but it’s not answering the prompt in a helpful way.
So, after the training process just described (known, weirdly, as pre-training), we can switch gears and instead give the model scores based on how much we actually like the answers rather than scores based on how well it predicts the next token in a sequence.
After pre-training, we give it some prompts. It generates some answers. Humans decide which of those answers is best, which is second best, etc. And the parameters are then tweaked to generate more likeable answers over time, again by using some version of gradient descent. (Eventually, we can also partially automate some of this process with a separate model that can score how good the answers are.) This process is known as Reinforcement Learning from Human Feedback (RLHF).
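Here is a deliberately cartoonish sketch of that loop. The real thing uses a learned reward model and policy-gradient updates to the full network; this toy version just shifts probability toward whichever canned answer the "humans" ranked highest, to show the shape of the feedback loop.

```python
import random

# Toy caricature of RLHF: the "model" is a probability distribution over three
# canned answers, and each round of feedback shifts probability toward the
# answer the humans liked best.
answers = ["helpful answer", "rambling continuation", "off-topic list"]
probs = [1 / 3, 1 / 3, 1 / 3]        # no preference before fine-tuning
human_favorite = "helpful answer"    # stand-in for human rankings

learning_rate = 0.1
for _ in range(200):
    sampled = random.choices(answers, weights=probs)[0]  # the model "generates" an answer
    reward = 1.0 if sampled == human_favorite else 0.0   # humans liked it, or they didn't
    i = answers.index(sampled)
    probs[i] += learning_rate * reward * (1 - probs[i])  # tweak toward well-liked answers
    total = sum(probs)
    probs = [p / total for p in probs]                   # keep it a valid distribution

print({a: round(p, 2) for a, p in zip(answers, probs)})  # most of the mass on the helpful answer
```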
Obviously, pre-training the model is quite a bit easier to scale than RLHF. There’s a lot of text on the internet just lying about, and even with very large models, you can get enough data to get them to be good at generating English prose. Once we require human feedback for evaluating responses, however, we get a bit bottlenecked. Nonetheless, it turns out that this human-feedback fine-tuning phase works pretty well despite the relatively low amount of data since the model is already pretty good.
And that’s how we end up with ChatGPT. We start by tweaking the parameters just based on how well it can predict the actual next token in a sequence. We then tweak the parameters based on how well humans like the responses it gives (and based on what a separate machine learning model thinks about how much humans will like the responses it gives).
Further Resources on Training
Gradient Descent and Backpropagation
This 3Blue1Brown series is probably the best friendly introduction for building intuition.
This article has some nice visual depictions of gradient descent and explains some important nuances.
Here’s a similar article on backpropagation from the same site.
Fine-Tuning with RLHF
Here’s a nice, detailed discussion of the general ideas behind RLHF and fine-tuning ChatGPT after pre-training.
There is a bit of an art to the random guesses, as you’ll want to find a good distribution from which to select the initial parameter values.
We won't worry here about any further discrepancies in the total count.
Really, it’s generating a probability distribution over the next token. But that probability distribution will not favor sensible or grammatical choices.
Actually, it will get a score for its prediction for the next token after "Believe", after "Believe truth!", and after "Believe truth! Shun", and it will make tweaks based on some average of these scores so that it would do better overall next time with the same sequence.
As usual, I’m leaving out a fair bit of nuance that isn’t conceptually important. The mathematically best way to implement this algorithm is to find the loss on all examples in your training set and then update the parameters to get a slightly better average result the next time on the whole training set. In practice, people will average across a small batch of training examples instead of just a single one.