Understanding & Implementing a Tiny GPT-2 (With Limited GPU)

Introduction

Everyone knows what GPT-2 is. We’ll try to understand it by implementing a tiny version of it. I say tiny because we can’t replicate the exact architecture with limited resources: I only have a single T4 GPU on Colab and probably 3 to 4 hours of runtime. So we’ll use a small dataset and a small architecture configuration, and call the result tiny GPT-2.

The goal of this ‘blog’ is not to replicate GPT-2’s full scale. Instead, the goal is to deeply understand:

  • What each step in the pipeline does
  • Why it is needed
  • How it connects to the previous step
  • What math happens inside it
  • How the code implements that math

If you want to skip the blog and go to the direct implementation, you can find it on my GitHub and on Colab.

I will try to explain every step thoroughly and we will carry one running example throughout:

“I think therefore I am.”


Notation Table

Symbol   Meaning
V        Vocabulary size (1024)
d        Embedding dimension (256)
T        Context window, block_size (64)
H        Attention heads (8)
X        Input embedding matrix
E        Token embedding matrix
P        Positional embedding matrix
Q, K, V  Query, Key, Value (V also denotes the vocabulary size; context makes clear which)
S        Attention scores
A        Attention weights
LN       LayerNorm
FFN      Feedforward network
CE       Cross entropy loss
logits   Raw output before softmax

1. Tokenization

1.1 Why Tokenization Exists

Transformers do not understand letters, words, or sentences. They only understand numbers. Tokenization bridges that gap:

Convert text into integers in a meaningful, consistent way.

Emphasis on “meaningful”. If “cat” is [20, 3, 99] today and [4, 77] tomorrow, the model cannot learn anything consistent. So, tokenization must be deterministic, reversible and stable across training and inference.

GPT-2 uses byte-level BPE (Byte Pair Encoding), which is efficient and can represent any text, in any language and including emojis.


1.2 Byte-Level BPE: Characters → Bytes → Merges

Byte-level BPE works by first turning text into raw bytes (0–255). Then it looks for byte pairs that frequently appear together. When it finds such a pair, it merges those two bytes into a single, larger token and assigns that new token its own integer ID. As this process repeats, the tokenizer gradually builds a small, efficient set of subword tokens.

This approach avoids the problems of character-level tokenization (too long) and word-level tokenization (too many unknown words). Because it starts from bytes, it can represent any text and because it learns common subword patterns, it keeps sequences short. Frequent patterns like “th” or “ing” become single tokens, while rare patterns stay broken into smaller pieces, so the tokenizer never produces an “unknown” token.

Let’s use UTF-8 for encoding:

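Text  I    ␣    t    h    i    n    k    ␣    t    h    e    r    e    f    o    r    e    ␣    I    ␣    a    m
Byte  73   32   116  104  105  110  107  32   116  104  101  114  101  102  111  114  101  32   73   32   97   109

(␣ marks a space. The final period, byte 46, is left out to keep the table short.)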

This is the raw material for BPE.
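To check the byte values yourself, a minimal sketch using only Python’s built-in UTF-8 encoding is enough. The variable names are just for illustration:

```python
# Sketch: inspect the UTF-8 bytes that byte-level BPE starts from.
text = "I think therefore I am"
byte_ids = list(text.encode("utf-8"))
print(byte_ids)       # one integer in 0..255 per byte
print(len(byte_ids))  # number of bytes in the sequence

# Non-ASCII characters simply occupy several bytes, so nothing is ever "unknown".
emoji_ids = list("🙂".encode("utf-8"))
print(emoji_ids)

# The mapping is reversible: the bytes decode back to the original string.
restored = bytes(byte_ids).decode("utf-8")
print(restored == text)
```

Every ASCII character maps to exactly one byte, while the emoji takes four.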

Let the initial sequence be \(X = [b_1, b_2, \ldots, b_n]\), where \(b_i \in \{0, \ldots, 255\}\) and \(V = \{0, 1, \ldots, 255\}\) be the initial vocabulary of bytes.

For any adjacent pair of bytes \(b_i\) and \(b_{i+1}\), we count the number of times they appear together in the sequence. This is denoted by \(C(b_i, b_{i+1})\).

During each merge, we select the most frequent pair of bytes and merge them into a new token \(t_{new}\) :

\[(u, v) = \arg\max_{(a,b)} C(a,b)\]

and replace every non-overlapping occurrence of \(uv\) with \(t_{new}\) inside \(X\). After merging, the length of \(X\) typically decreases (it “compresses”), and we recompute pair counts on the new sequence. We repeat this until the vocabulary reaches the target size (e.g., 4096).

Intuitively, high-frequency neighboring bytes (or already-merged tokens) become single tokens. Over many merges, common character bigrams, trigrams, and word pieces are “absorbed” into compact subword tokens. This improves efficiency downstream: sequences become shorter, and the model can learn reusable subword patterns.
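The merge procedure above can be sketched in a few lines of plain Python. This is a toy illustration, not the optimized tokenizer used later; the function names (count_pairs, merge_pair, train_bpe) and the first-seen tie-breaking rule are assumptions made for the example:

```python
from collections import Counter

def count_pairs(ids):
    """Count how often each adjacent pair (a, b) occurs in the sequence."""
    return Counter(zip(ids, ids[1:]))

def merge_pair(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`, left to right."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(ids, target_vocab_size):
    """Greedy BPE: merge the most frequent pair until the vocabulary is full."""
    ids = list(ids)
    merges = {}
    next_id = 256  # IDs 0..255 are reserved for raw bytes
    while next_id < target_vocab_size:
        counts = count_pairs(ids)  # recomputed after every merge
        if not counts:
            break
        # max() returns the first maximal item, so the first-seen pair wins ties
        pair, freq = max(counts.items(), key=lambda kv: kv[1])
        if freq < 2:
            break  # nothing repeats, so another merge would not compress anything
        ids = merge_pair(ids, pair, next_id)
        merges[pair] = next_id
        next_id += 1
    return ids, merges

byte_ids = list("I think therefore I am".encode("utf-8"))
compressed, merges = train_bpe(byte_ids, target_vocab_size=260)
print(merges)
print(len(byte_ids), "->", len(compressed))
```

On a single sentence the loop stops early because no pair repeats any more; on a real corpus it keeps going until the target vocabulary size is reached.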


1.3 Counting Adjacent Pairs

BPE is a greedy algorithm: at each step we only merge the most frequent adjacent pair. This requires recomputing pair statistics after every merge because merges change the sequence and therefore the counts. Selecting the most frequent pair ensures we prioritize the merges that yield the biggest compression and the most reusable subword units.

Example counts:

Pair   Bytes        Count
“I␣”   (73, 32)     2
“␣t”   (32, 116)    2
“th”   (116, 104)   2
“re”   (114, 101)   2

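The most frequent pair is merged into a new token. Here four pairs tie at a count of 2, so the implementation breaks the tie by a fixed rule (for example, first seen). If each of these pairs were merged, they would receive new IDs starting right after the byte range:

73, 32 → 256
32, 116 → 257
116, 104 → 258
114, 101 → 259

Merges interact, though: once “I␣” becomes token 256, the space in front of “think” is used up, so the count of “␣t” drops to 1. This is why pair counts are recomputed after every merge.

Either way, the sequence compresses: merging both occurrences of “I␣” makes it two elements shorter.

This repeats until we reach the target vocabulary size (1024).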


1.4 Why Byte-Level BPE Works

Why not just tokenize by words? Or why not just keep characters?

Well, byte-level BPE answers both concerns. Characters make sequences too long; words produce out-of-vocabulary problems and brittle vocabularies. Byte-level BPE starts from the universal building blocks (bytes) and learns useful subwords, guaranteeing reversibility and zero unknowns while keeping sequences reasonably short, almost like a child learning to spell.


1.5 Tokenizer Code

We use Hugging Face’s ByteLevelBPETokenizer, which implements the algorithm in optimized Rust. We don’t have to write the merge loop ourselves; instead, we point it to our corpus and ask it to learn a vocabulary of a given size. Internally, it performs the pair-counting, the argmax selection, the creation of new token IDs and the non-overlapping replacements repeatedly until it reaches the requested vocabulary size.

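import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

# file_path: path to the plain-text training corpus (defined earlier in the notebook)
# text: the contents of that corpus as a single string (loaded earlier in the notebook)

tokenizer = ByteLevelBPETokenizer()
# ByteLevelBPETokenizer already configures byte-level pre-tokenization and decoding;
# these two lines only make that explicit.
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()
tokenizer.train(
    files=[file_path],
    vocab_size=1024,
    min_frequency=2,
    special_tokens=["<|endoftext|>"],
)

# Encode the whole corpus into token IDs
all_ids = tokenizer.encode(text).ids
all_ids = np.array(all_ids, dtype=np.int32)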

2. Preparing Training Data

After converting the entire corpus into a long sequence of token IDs in Section 1, we now need to transform that stream into training examples the model can learn from. This involves three steps:

  1. Train/Val Split: We split the sequence into a training set and a validation set, typically using a 90/10 ratio.
  2. Sampling Windows: We define fixed-size windows (e.g., 128 tokens) to create training examples.
  3. get_batch(): We implement a function to sample random windows from the data.

2.1 Train/Val Split

We divide the long token sequence into two parts: 90% for training and 10% for validation. The model learns from the training portion, and the validation portion tells us how well the model generalizes to unseen text.

n_tokens = len(all_ids)
split_idx = int(0.9 * n_tokens)
train_tokens = all_ids[:split_idx]
val_tokens   = all_ids[split_idx:]

2.2 Sampling Windows

We choose a context window length (block_size) that tells the model how many tokens it can “see” at once. GPT-style models operate on fixed-length sequences, so every training example must have the same number of tokens. We are keeping our block size small because we are training on a haiku dataset.

block_size = 64
batch_size = 256

2.3 get_batch()

We generate a batch of random windows from the dataset. Each window becomes an input sequence x, and each corresponding shifted window becomes a target sequence y.

Random sampling prevents the model from learning only the local flow of the text (e.g., Shakespeare line-by-line). Instead, it sees many varied contexts across training, which improves generalization and stability.

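def get_batch(split):
    # Pick the token array for the requested split
    data = train_tokens if split == "train" else val_tokens
    # Last valid start index: the target window needs one extra token at the end
    max_start = len(data) - block_size - 1
    idx = np.random.randint(0, max_start + 1, size=(batch_size,))
    x = np.stack([data[i : i + block_size] for i in idx])          # inputs
    y = np.stack([data[i + 1 : i + 1 + block_size] for i in idx])  # targets, shifted by one
    # Note: x and y are NumPy arrays; convert them to torch tensors
    # (targets as torch.long) before passing them to the model.
    return x, y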

2.4 Math behind the Training Windows

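Each training window creates one prediction task per position, so a window of \(T\) tokens (block_size) yields \(T\) tasks. Take \(T + 1\) consecutive tokens \(t_1, \ldots, t_{T+1}\) from the corpus. The input is:

\[x = (t_1, t_2, \ldots, t_T)\]

and the targets are the same tokens shifted by one position:

\[y_1 = t_2, \; y_2 = t_3, \; \ldots, \; y_T = t_{T+1}\]

This is exactly the autoregressive language modeling objective, where each target is predicted from the inputs up to and including its own position:

\[P(y_i \mid x_1, \ldots, x_i)\]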

By sliding this 1-token shift across millions of windows, the model learns grammar, style, structure and statistical regularities of the language.
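To make the one-token shift concrete, here is a small sketch on a toy token stream. It uses plain Python lists instead of NumPy, and sample_window is a hypothetical helper written only for this illustration:

```python
import random

def sample_window(tokens, block_size, rng):
    """Pick a random start and return (inputs, targets), with targets shifted by one."""
    start = rng.randrange(0, len(tokens) - block_size)  # leaves room for the extra target
    x = tokens[start : start + block_size]
    y = tokens[start + 1 : start + 1 + block_size]
    return x, y

tokens = list(range(100, 120))  # stand-in for a stream of token IDs
rng = random.Random(0)
x, y = sample_window(tokens, block_size=4, rng=rng)

# Each position is its own prediction task: a growing context and one target.
for i in range(len(x)):
    print(f"context {x[: i + 1]} -> target {y[i]}")
```

A single window of 4 tokens therefore gives 4 training signals, not 1.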


3. Embeddings

Up to this point, we have converted our raw text into a long sequence of token IDs. These IDs are integers (like 391, 82, 256, etc.) that come from the Byte-Level BPE tokenizer. Although these integers are essential for compactly representing text, they do not carry any semantic information. The number 256 does not mean anything by itself. It just identifies a subword unit (a particular BPE merge) in the vocabulary.

A Transformer cannot operate directly on these integer IDs. It needs each token represented as a vector of real numbers, so for a single token we use a vector of length \(d = 256\).

3.1 Why Embeddings Exist

When a token ID like 73 (the byte for “I”) enters the model, nothing distinguishes it from 116 (“t”) or 32 (space). If we fed these raw integers into a neural network, the model would learn nothing meaningful because numerical closeness has no semantic meaning. For example, 32 (space) is not semantically “closer” to 38 than to 200, but a raw neural network would be forced to treat it as such.

Embeddings solve this issue by mapping each integer token ID to a learned vector in a continuous space. This allows the model to develop a representation of what tokens mean, how they relate to one another, and how they function in context.

Every token receives a vector of length d = 256 in our implementation. Over training, tokens that behave similarly in text (like “the”, “a”, “an”) will come to have vectors that live relatively close together in this space.


3.2 Token Embeddings

We define the token embedding layer as:

self.tok_emb = nn.Embedding(vocab_size, n_embd)

nn.Embedding(vocab_size, n_embd) creates a lookup table that maps each token ID to its corresponding embedding vector:

\[E \in \mathbb{R}^{V \times d}\]

Where:

  • $V$ is the vocabulary size (number of unique tokens, in our case it’s 1024, the vocab_size we trained the tokenizer with)
  • $d$ is the embedding dimension (256 in our implementation)

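Let’s say the tokenizer gave the following IDs (they are illustrative; with a vocabulary of 1024, real IDs lie between 0 and 1023):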

“I” → 73
“think” (after BPE merges) → 502
“therefore” → 1112
“I” → 73
“am” (after merges) → 870

The token embedding layer converts the integer sequence into 256-dimensional vectors. Each of these is a learned dense vector such as:

E[73]   = [0.12, -1.04, 0.55, ..., 0.73]
E[502]  = [-0.22, 0.15, 1.98, ..., -0.34]
E[1112] = [0.01, 0.99, -0.01, ..., 0.02]
E[73]   = [0.12, -1.04, 0.55, ..., 0.73]
E[870]  = [0.01, 0.99, -0.01, ..., 0.02]

These numbers are randomly initialized at the start but during training they will converge into meaningful representation patterns.
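One way to see that an embedding layer is “just a lookup table” is to compare row indexing with multiplying a one-hot vector by the table. The sketch below uses a tiny random table in plain Python; the sizes and the helper names embed and one_hot_matmul are assumptions for the illustration, not part of the actual model code:

```python
import random

rng = random.Random(0)
V, d = 6, 3  # toy vocabulary size and embedding dimension

# Embedding table E with V rows and d columns, randomly initialized
E = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(V)]

def embed(token_ids):
    """Embedding lookup: pick row E[t] for every token ID t."""
    return [E[t] for t in token_ids]

def one_hot_matmul(token_id):
    """The same result written as a one-hot row vector times E."""
    one_hot = [1.0 if v == token_id else 0.0 for v in range(V)]
    return [sum(one_hot[v] * E[v][j] for v in range(V)) for j in range(d)]

vectors = embed([2, 5, 2])
print(vectors[0] == vectors[2])  # the same token always maps to the same vector
print(one_hot_matmul(5))
print(E[5])
```

The lookup is simply the cheap way of doing that matrix product, which is why gradients only touch the rows of tokens that actually appear in a batch.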

3.3 Positional Embeddings

Even if we embed each token into a meaningful vector, the model cannot distinguish “I think therefore I am” from “am I therefore think I”. Both sequences contain exactly the same token embeddings, just in a different order.

Transformers process all positions simultaneously and have no inherent notion of sequence order. This is unlike RNNs, where order is built into the recurrence; Transformers need order to be injected manually. This is where positional embeddings come in, which we define by:

self.pos_emb = nn.Embedding(block_size, n_embd)

This creates a learned matrix:

\[P \in \mathbb{R}^{T \times d}\]

Where:

  • $T$ is the block size or the maximum sequence length (block_size, set to 64 in Section 2.2)
  • $d$ is the embedding dimension (256 in our implementation)

and each row P[t] is a learned vector that encodes the position of the token in the sequence.

The input to the Transformer becomes:

\[X[t] = E[\text{token}_t] + P[t]\]

which fuses the token identity and position information into a single vector.

For our sequence I think therefore I am, the positional embeddings would be:

\[P = [P[0], P[1], P[2], P[3], P[4]]\]

And the final embedding would look like this:

X[0] = E[73] + P[0]
X[1] = E[502] + P[1]
X[2] = E[1112] + P[2]
X[3] = E[73] + P[3]
X[4] = E[870] + P[4]

You will notice that even though the token “I” appears twice, its two occurrences get different positional embeddings, so X[0] and X[3] differ. This allows the model to learn position-dependent patterns such as word order and sentence structure.
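The sum of token and positional embeddings can be sketched with small random tables. Here E is a dictionary standing in for the full lookup table, d is shrunk to 4, and all values are random placeholders rather than trained weights:

```python
import random

rng = random.Random(0)
d = 4
token_ids = [73, 502, 1112, 73, 870]  # illustrative IDs for "I think therefore I am"

# Stand-ins for the learned tables: one vector per token ID, one per position
E = {tid: [rng.gauss(0.0, 1.0) for _ in range(d)] for tid in sorted(set(token_ids))}
P = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(len(token_ids))]

# X[t] = E[token_t] + P[t], added element by element
X = [[e + p for e, p in zip(E[tid], P[t])] for t, tid in enumerate(token_ids)]

print(E[token_ids[0]] == E[token_ids[3]])  # same token, same token embedding
print(X[0] == X[3])                        # but different input vectors
```

Both occurrences of “I” share one token embedding, and only the positional term tells them apart.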


4. Transformer Block

A GPT model is made of a stack of identical Transformer blocks. In our implementation, we use 4 blocks, and each block performs the same sequence of operations.

A single Transformer block takes in a sequence of embeddings:

\[X \in \mathbb{R}^{B \times T \times d}\]

Where:

  • $B$ is the batch size (batch_size, set to 256 in Section 2.2)
  • $T$ is the block size or the maximum sequence length (block_size, set to 64 in Section 2.2)
  • $d$ is the embedding dimension (256 in our implementation)

To understand the block, we will break it down into its components:

  1. (Pre) LayerNorm
  2. Self-Attention
  3. Feedforward Network (MLP)
  4. Residual Connections

4.1 (Pre) LayerNorm

We use the pre-norm Transformer design, where normalization happens before attention and before the feedforward MLP.

class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # Pre-norm + residual
        x = x + self.mlp(self.ln2(x))   # Pre-norm + residual
        return x

LayerNorm operates per token, across its feature dimensions.

For each token vector \(x_t \in \mathbb{R}^{d}\):

  1. We compute the mean:
\[\mu_t = \frac{1}{d} \sum_{i=1}^{d} x_{t,i}\]
  2. We compute the variance:
\[\sigma_t^2 = \frac{1}{d} \sum_{i=1}^{d} (x_{t,i} - \mu_t)^2\]
  3. We normalize each feature of the token vector:
\[\hat{x}_{t,i} = \frac{x_{t,i} - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}}\]
  4. We scale and shift the normalized vector with learned parameters:
\[LN(x_t)_i = \gamma_i \hat{x}_{t,i} + \beta_i\]

Let’s assume the model produced this pre-attention embedding vector for “therefore”:

\[x_t = [-1.7, 0.2, 0.8, -0.5]\]

The mean and variance are:

\[\mu_t = \frac{-1.7 + 0.2 + 0.8 - 0.5}{4} = -0.3\]
\[\sigma_t^2 = \frac{(-1.4)^2 + 0.5^2 + 1.1^2 + (-0.2)^2}{4} = \frac{3.46}{4} = 0.865\]

Then the normalized vector (ignoring the tiny \(\epsilon\)) is:

\[\hat{x}_t = \frac{[-1.7, 0.2, 0.8, -0.5] - (-0.3)}{\sqrt{0.865}} = \frac{[-1.4, 0.5, 1.1, -0.2]}{0.930} \approx [-1.51, 0.54, 1.18, -0.22]\]

And the final LayerNorm output is:

\[LN(x_t)_i = \gamma_i \hat{x}_{t,i} + \beta_i\]

where \(\gamma_i\) and \(\beta_i\) are learnable parameters.

Let’s assume \(\gamma_i = 1\) and \(\beta_i = 0\) for simplicity (this is also how PyTorch initializes them).

So, the output is:

\[LN(x_t) = \hat{x}_t \approx [-1.51, 0.54, 1.18, -0.22]\]
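The four LayerNorm steps are easy to check by hand with a few lines of plain Python. This is a sketch of the formula for a single token vector, not PyTorch’s implementation; eps=1e-5 is assumed here because it is the nn.LayerNorm default:

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """LayerNorm for one token vector: normalize across features, then scale and shift."""
    d = len(x)
    gamma = gamma if gamma is not None else [1.0] * d  # learnable scale, here all ones
    beta = beta if beta is not None else [0.0] * d     # learnable shift, here all zeros
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d          # divides by d, not d - 1
    x_hat = [(v - mean) / math.sqrt(var + eps) for v in x]
    return [g * h + b for g, h, b in zip(gamma, x_hat, beta)]

x_t = [-1.7, 0.2, 0.8, -0.5]
mean = sum(x_t) / len(x_t)
var = sum((v - mean) ** 2 for v in x_t) / len(x_t)
out = layer_norm(x_t)

print(round(mean, 4), round(var, 4))
print([round(v, 2) for v in out])
```

Whatever the input, the normalized vector has mean 0 and variance close to 1, which is what keeps activations in a stable range from block to block.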

4.2 Self-Attention

Self-attention is the most important part of GPT-style models. It allows each token to “look at” other tokens in the sequence and decide which earlier tokens are relevant and how much attention it should pay to each of them.

When we read the sentence “I think therefore I am.”, we interpret the word “therefore” by paying attention to the preceding phrase “I think” and the following phrase “I am”. Attention lets the model learn these relationships automatically (in GPT, restricted to earlier tokens, as Step 4 shows). This is done with Query, Key and Value matrices.


Step 1 — Q, K, V

Let’s walk through an example with the sentence “I think therefore I am”. Suppose that after token and positional embedding and layer normalization, each token is a 4-dimensional vector (the values are illustrative):

X = [
  [1.0, 0.5, 0.2, 0.8],        ← "I"
  [-0.3, 1.9, -0.6, 0.4],      ← "think"
  [-1.7, 0.2, 0.8, -0.5],      ← "therefore"
  [1.0, 0.5, 0.2, 0.8],        ← "I"
  [-1.34, -0.45, 0.45, 1.34],  ← "am"
]

We create Q, K and V by multiplying X by three different weight matrices, giving three different representations:

Q = X W_Q (Queries)
K = X W_K (Keys)
V = X W_V (Values)

So, what are these weight matrices? They are learned during training. For this walkthrough, assume all three are the 4×4 identity matrix:

W_Q = W_K = W_V = [
  [1, 0, 0, 0],
  [0, 1, 0, 0],
  [0, 0, 1, 0],
  [0, 0, 0, 1]
]

So Q = K = V = X. In a trained model, the three matrices differ and transform the embeddings differently.

The code for calculating Q, K, V is:

self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)

Step 2 — Computing Attention Scores

Now each token needs to figure out: “How relevant is every other token to me?” We answer this by taking the dot product of every query with every key:

\[score(i, j) = Q_i \cdot K_j\]

Token 0 (“I”) queries all keys:

Q[0] · K[0]:
[1.0, 0.5, 0.2, 0.8] · [1.0, 0.5, 0.2, 0.8]
= 1.0×1.0 + 0.5×0.5 + 0.2×0.2 + 0.8×0.8
= 1.0 + 0.25 + 0.04 + 0.64
= 1.93

Q[0] · K[1]:
[1.0, 0.5, 0.2, 0.8] · [-0.3, 1.9, -0.6, 0.4]
= 1.0×(-0.3) + 0.5×1.9 + 0.2×(-0.6) + 0.8×0.4
= -0.3 + 0.95 - 0.12 + 0.32
= 0.85

Q[0] · K[2]:
[1.0, 0.5, 0.2, 0.8] · [-1.7, 0.2, 0.8, -0.5]
= 1.0×(-1.7) + 0.5×0.2 + 0.2×0.8 + 0.8×(-0.5)
= -1.7 + 0.1 + 0.16 - 0.4
= -1.84

Q[0] · K[3]:
[1.0, 0.5, 0.2, 0.8] · [1.0, 0.5, 0.2, 0.8]
= 1.93 (same as Q[0]·K[0])

Q[0] · K[4]:
[1.0, 0.5, 0.2, 0.8] · [-1.34, -0.45, 0.45, 1.34]
= 1.0×(-1.34) + 0.5×(-0.45) + 0.2×0.45 + 0.8×1.34
= -1.34 - 0.225 + 0.09 + 1.072
= -0.403

Row 0 scores: [1.93, 0.85, -1.84, 1.93, -0.403]

Similarly, Token 1 ("think") queries all keys:

Q[1] · K[0]:
[-0.3, 1.9, -0.6, 0.4] · [1.0, 0.5, 0.2, 0.8]
= -0.3×1.0 + 1.9×0.5 + (-0.6)×0.2 + 0.4×0.8
= -0.3 + 0.95 - 0.12 + 0.32
= 0.85

Q[1] · K[1]:
[-0.3, 1.9, -0.6, 0.4] · [-0.3, 1.9, -0.6, 0.4]
= (-0.3)×(-0.3) + 1.9×1.9 + (-0.6)×(-0.6) + 0.4×0.4
= 0.09 + 3.61 + 0.36 + 0.16
= 4.22

Q[1] · K[2]:
[-0.3, 1.9, -0.6, 0.4] · [-1.7, 0.2, 0.8, -0.5]
= (-0.3)×(-1.7) + 1.9×0.2 + (-0.6)×0.8 + 0.4×(-0.5)
= 0.51 + 0.38 - 0.48 - 0.2
= 0.21

Q[1] · K[3]:
[-0.3, 1.9, -0.6, 0.4] · [1.0, 0.5, 0.2, 0.8]
= 0.85 (same as Q[1]·K[0])

Q[1] · K[4]:
[-0.3, 1.9, -0.6, 0.4] · [-1.34, -0.45, 0.45, 1.34]
= (-0.3)×(-1.34) + 1.9×(-0.45) + (-0.6)×0.45 + 0.4×1.34
= 0.402 - 0.855 - 0.27 + 0.536
= -0.187

Row 1 scores: [0.85, 4.22, 0.21, 0.85, -0.187]

Similarly for Row 2, 3 and 4.

Complete Score Matrix (Q K^T):

Score Matrix S = Q K^T = [
  [1.93,   0.85,  -1.84,   1.93,  -0.403],  ← Token 0
  [0.85,   4.22,   0.21,   0.85,  -0.187],  ← Token 1
  [-1.84,  0.21,   3.82,  -1.84,   1.878],  ← Token 2
  [1.93,   0.85,  -1.84,   1.93,  -0.403],  ← Token 3
  [-0.403, -0.187, 1.878, -0.403,  3.9962]  ← Token 4
]

Shape: [5, 5]

Each row represents one token’s attention scores to all tokens.
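The full score matrix can be reproduced with a short plain-Python sketch. It assumes, as in the walkthrough, that the weight matrices are the identity so that Q = K = X; dot is a small helper defined just for this example:

```python
# Token vectors from the walkthrough ("I", "think", "therefore", "I", "am")
X = [
    [1.0, 0.5, 0.2, 0.8],
    [-0.3, 1.9, -0.6, 0.4],
    [-1.7, 0.2, 0.8, -0.5],
    [1.0, 0.5, 0.2, 0.8],
    [-1.34, -0.45, 0.45, 1.34],
]
Q = K = X  # identity weight matrices, as assumed in the text

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))

# S[i][j] = Q[i] . K[j], i.e. S = Q K^T
S = [[dot(q, k) for k in K] for q in Q]

for row in S:
    print([round(s, 4) for s in row])
```

With Q = K the matrix is symmetric; with separately learned W_Q and W_K it generally is not.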


Step 3 — Scale by √d

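Now, we divide the scores by \(\sqrt{d}\). Without this, dot products grow with the vector dimension and push softmax toward near one-hot weights with very small gradients. Note that this is for a single head; with multiple heads, each head divides by \(\sqrt{d / H}\), the square root of its own dimension.

Formula: \(S_{scaled}[i,j] = \frac{S[i,j]}{\sqrt{d}}\)

Here d = 4, so √d = 2, and every score is halved: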

Scaled Scores = S / √4 = [
  [0.965,  0.425, -0.92,   0.965,  -0.2015],
  [0.425,  2.11,   0.105,  0.425,  -0.0935],
  [-0.92,  0.105,  1.91,  -0.92,    0.939],
  [0.965,  0.425, -0.92,   0.965,  -0.2015],
  [-0.2015, -0.0935, 0.939, -0.2015, 1.9981]
]

Step 4 — Apply Causal Mask (For Autoregressive Models like GPT)

In GPT, each token can only look at itself and earlier tokens, not future ones. Token 1 (“think”), for example, can see:

Token 0 ("I") ✓
Token 1 ("think" - itself) ✓
Token 2 ("therefore") ✗ (future - mask it!)

We set future positions to -∞ so they become 0 after softmax:

After Causal Mask:
[
  [0.965,  -∞,    -∞,     -∞,     -∞    ],  ← Token 0 (only sees itself)
  [0.425,  2.11,  -∞,     -∞,     -∞    ],  ← Token 1 (sees 0,1)
  [-0.92,  0.105, 1.91,   -∞,     -∞    ],  ← Token 2 (sees 0,1,2)
  [0.965,  0.425, -0.92,  0.965,  -∞    ],  ← Token 3 (sees 0,1,2,3)
  [-0.2015, -0.0935, 0.939, -0.2015, 1.9981]  ← Token 4 (sees all)
]

Step 5 — Apply Softmax (Get Attention Weights)

Softmax converts each row of scores into weights that sum to 1:

\[A[i,j] = \frac{\exp(S_{scaled}[i,j])}{\sum_{k} \exp(S_{scaled}[i,k])}\]

Token 0 (“I”) as input: [0.965, -∞, -∞, -∞, -∞]

e^{0.965} ≈ 2.625
e^{-∞} = 0

Sum: 2.625

Weights: [2.625/2.625, 0, 0, 0, 0] = [1.0, 0, 0, 0, 0]

Token 0 pays 100% attention to itself (it can’t see the future).

Token 1 (“think”) as input: [0.425, 2.11, -∞, -∞, -∞]

e^0.425 ≈ 1.529
e^2.11 ≈ 8.247

Sum: 1.529 + 8.247 = 9.776

Weights:

[1.529/9.776, 8.247/9.776, 0, 0, 0] ≈ [0.156, 0.844, 0, 0, 0]

Token 1 pays:

  • 15.6% attention to “I” (token 0)
  • 84.4% attention to itself

The remaining rows follow the same procedure.

This is the final Attention Weight Matrix:

Attention Weights A = softmax(Masked Scores) = [
  [1.0,   0,     0,     0,     0    ],  ← Token 0
  [0.156, 0.844, 0,     0,     0    ],  ← Token 1
  [0.048, 0.134, 0.817, 0,     0    ],  ← Token 2
  [0.366, 0.213, 0.056, 0.366, 0    ],  ← Token 3
  [0.065, 0.073, 0.205, 0.065, 0.591]   ← Token 4
]
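Steps 3 to 5 (scale, mask, softmax) can be reproduced from the raw score matrix with the sketch below. It is plain Python for readability; causal_softmax is a hypothetical helper, and subtracting the row maximum before exponentiating is a standard numerical-stability trick that does not change the result:

```python
import math

# Raw score matrix S = Q K^T from Step 2
S = [
    [1.93, 0.85, -1.84, 1.93, -0.403],
    [0.85, 4.22, 0.21, 0.85, -0.187],
    [-1.84, 0.21, 3.82, -1.84, 1.878],
    [1.93, 0.85, -1.84, 1.93, -0.403],
    [-0.403, -0.187, 1.878, -0.403, 3.9962],
]
d = 4

def causal_softmax(scores, d):
    """Scale by sqrt(d), mask future positions with -inf, then softmax each row."""
    T = len(scores)
    weights = []
    for i in range(T):
        row = [scores[i][j] / math.sqrt(d) if j <= i else float("-inf") for j in range(T)]
        m = max(row)                           # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]  # exp(-inf) is exactly 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

A = causal_softmax(S, d)
for row in A:
    print([round(a, 3) for a in row])
```

Each row sums to 1, and every entry above the diagonal is exactly 0, so no token receives information from the future.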

Step 6 — Weighted Sum of Values (Information Flow)

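Now each token’s output is the attention-weighted sum of the value vectors:

\[output[i] = \sum_{j} A[i,j] V[j]\]

This contextualizes each token: its new vector blends information from every token it attends to. For example, token 1 (“think”) becomes 0.156 · V[0] + 0.844 · V[1].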
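As a worked example of this last step, the sketch below multiplies the attention weights from Step 5 by the value vectors, again assuming V = X as in the walkthrough:

```python
# Attention weights from Step 5 (rounded to three decimals)
A = [
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.156, 0.844, 0.0, 0.0, 0.0],
    [0.048, 0.134, 0.817, 0.0, 0.0],
    [0.366, 0.213, 0.056, 0.366, 0.0],
    [0.065, 0.073, 0.205, 0.065, 0.591],
]
# Value vectors; V = X because the walkthrough uses identity weight matrices
V = [
    [1.0, 0.5, 0.2, 0.8],
    [-0.3, 1.9, -0.6, 0.4],
    [-1.7, 0.2, 0.8, -0.5],
    [1.0, 0.5, 0.2, 0.8],
    [-1.34, -0.45, 0.45, 1.34],
]

# output[i] = sum_j A[i][j] * V[j], computed feature by feature
output = [
    [sum(A[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
    for i in range(len(A))
]

for row in output:
    print([round(v, 3) for v in row])
```

Token 0 keeps its own value vector unchanged, while later tokens become mixtures of the tokens before them.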


4.3 Multi-Head Attention

So far, we acted as if there was just one “attention head” computing \(Q\), \(K\) and \(V\), the scores, the weights and finally the weighted sum of values. In practice, GPT-2 (and our tiny version) uses multi-head attention.

The idea is simple:

  • One head can only learn one kind of relation pattern at a time.
  • Multiple heads can look at the same sequence from different “perspectives” or “subspaces”.

We start with embeddings of dimension \(d = 256\) and split them into \(H = 8\) heads. Then each head operates in a smaller subspace of size:

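\[\text{head\_dim} = d / H = 256 / 8 = 32\]

So, instead of a single set of \(d\)-dimensional queries, keys and values per token:

\[Q, K, V \in \mathbb{R}^{B \times T \times d}\]

we have \(H\) smaller sets, one per head:

\[Q, K, V \in \mathbb{R}^{B \times H \times T \times \text{head\_dim}}\]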

Each head runs its own scaled dot-product attention, then we concatenate the results from all heads back into a 256-dimensional vector per token.

In code, we do this by projecting once into a big \(3d\)-dimensional space and then reshaping:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        # One projection to produce Q, K, V stacked along the last dim
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd)

        # Causal mask (True above the diagonal)
        mask = torch.triu(torch.ones(block_size, block_size), 1)
        self.register_buffer("mask", mask == 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape  # (batch, time, channels)

        # 1) Project once, then split into Q, K, V
        qkv = self.qkv(x)          # (B, T, 3*C)
        q, k, v = qkv.split(C, dim=2)  # each (B, T, C)

        # 2) Reshape into heads: (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # 3) Scaled dot-product attention per head
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = att.masked_fill(self.mask[:T, :T], float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        # 4) Weighted sum of values
        out = att @ v  # (B, n_head, T, head_dim)

        # 5) Concatenate heads back into (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # 6) Final linear projection (mix head outputs)
        out = self.out_proj(out)

        return out

Intuitively, one head might learn “which subject pronoun does this verb refer to?” and another head might learn “which earlier tokens give me tense and aspect information?”. In our running sentence “I think therefore I am”, you can imagine one head focusing on how “therefore” connects the two clauses and another head focusing on the grammatical structure around “I” and “am”.

All of these perspectives are combined back together, giving each token a rich, context-aware representation.
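The reshaping into heads is just slicing each token vector into H equal chunks. The sketch below shows the idea on nested Python lists for one sequence (no batch dimension); split_heads and merge_heads are hypothetical helpers that mirror the view/transpose calls in the PyTorch code:

```python
def split_heads(x, n_head):
    """x: T token vectors of length d -> n_head matrices of shape (T, head_dim)."""
    d = len(x[0])
    assert d % n_head == 0
    head_dim = d // n_head
    # Head h sees features h*head_dim .. (h+1)*head_dim - 1 of every token
    return [[tok[h * head_dim : (h + 1) * head_dim] for tok in x] for h in range(n_head)]

def merge_heads(heads):
    """Concatenate the per-head vectors of each token back into one vector of length d."""
    T = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(T)]

# Toy input: T = 3 tokens, d = 8 features, H = 2 heads
x = [[float(10 * t + i) for i in range(8)] for t in range(3)]
heads = split_heads(x, n_head=2)

print(len(heads), len(heads[0]), len(heads[0][0]))  # H, T, head_dim
print(heads[1][0])                                  # second half of token 0
print(merge_heads(heads) == x)
```

Each head then runs the same scaled dot-product attention on its own slice, and the concatenation restores the original width before the final projection.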


4.4 Feedforward (MLP)

Self-attention mixes information across positions (tokens look at other tokens). However, after that mixing step, each token also needs a way to individually process its own representation non-linearly. This is the job of the feedforward network (FFN), sometimes called the MLP block.

In each Transformer block we apply:

\[FFN(x) = GELU(x W_1 + b_1) W_2 + b_2\]

where

  • \(W_1 \in \mathbb{R}^{d \times 4d}\) (expands from \(d\) to \(4d\))
  • \(W_2 \in \mathbb{R}^{4d \times d}\) (projects back to \(d\))
  • \(b_1\) and \(b_2\) are bias vectors
  • \(GELU\) is the GELU activation function.

In our implementation, \(d = 256\), so the hidden size is \(4d = 1024\).

Code

class FeedForward(nn.Module):
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # expand: d → 4d
            nn.GELU(),                      # PyTorch's built-in GELU activation
            nn.Linear(4 * n_embd, n_embd),  # project back: 4d → d
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Applied to every token position independently
        return self.net(x)

Why do we expand to \(4d\) and then shrink back?

  • The expansion gives the model extra capacity to represent complex combinations of features.
  • The nonlinearity (GELU) lets the model build rich, non-linear functions of the attention output.
  • Shrinking back to \(d\) keeps the dimensionality consistent for residual connections.

If we zoom into a single token’s vector after attention, call it \(z \in \mathbb{R}^{256}\), the FFN does:

  1. Linearly project to \(\mathbb{R}^{1024}\), think of this as generating many candidate features.
  2. Apply GELU to decide which features are “softly activated”.
  3. Linearly project back to \(\mathbb{R}^{256}\), compressing all that information down into an updated token representation.

So attention says “what should I look at across the sentence?”, and the FFN says “given what I just looked at, how should I transform myself?”.
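To see the expand, activate, project pattern on actual numbers, here is a plain-Python sketch of the FFN for a single token vector. The weights are random placeholders, d is shrunk to 2, and the exact (erf-based) form of GELU is assumed; dropout is left out:

```python
import math
import random

def gelu(x):
    """Exact GELU: x times the standard normal CDF at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(x, W, b):
    """x: length n_in, W: n_in x n_out, b: length n_out -> vector of length n_out."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

def ffn(x, W1, b1, W2, b2):
    hidden = [gelu(h) for h in linear(x, W1, b1)]  # d -> 4d, then non-linearity
    return linear(hidden, W2, b2)                  # 4d -> d

rng = random.Random(0)
d = 2
W1 = [[rng.gauss(0.0, 1.0) for _ in range(4 * d)] for _ in range(d)]
b1 = [0.0] * (4 * d)
W2 = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(4 * d)]
b2 = [0.0] * d

z = [0.5, -1.0]  # one token's vector after attention
print(ffn(z, W1, b1, W2, b2))
print([round(gelu(v), 4) for v in (-3.0, -1.0, 0.0, 1.0, 3.0)])
```

GELU behaves like the identity for large positive inputs and squashes large negative inputs toward 0, with a smooth transition in between, which is the “soft activation” mentioned above.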


4.5 Residual Connections (Putting the Block Together)

Now we can see the full TransformerBlock:

class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x):
        # 1) Pre-norm + attention + residual
        x = x + self.attn(self.ln1(x))

        # 2) Pre-norm + feedforward + residual
        x = x + self.mlp(self.ln2(x))

        return x

Key ideas:

  1. Pre-LayerNorm: We normalize before attention and before the MLP:
    • This stabilizes training, especially for deeper stacks of blocks.
    • It helps gradients flow more smoothly.
  2. Residual connections: The output of each major sublayer (attention, then FFN) is added back to its input:
    • x = x + sublayer(LN(x))

    This means the network can always “fall back” to the identity function if needed. Practically, it:

    • Makes optimization easier (better gradient flow).
    • Lets deeper models train without collapsing.

If you imagine passing our running sentence I think therefore I am through a stack of 4 such blocks:

  • At early layers, the residual means we’re only making small tweaks to the embeddings.
  • At later layers, the accumulated updates encode rich patterns: semantic relationships, syntax, phrase structure, etc.

The block is the core “processing unit” that we repeat multiple times.


5. Full Tiny GPT-2 Model

Now, we put everything together in our GPT2Tiny class. It contains the token embeddings, the positional embeddings, a stack of n_layer transformer blocks (4 in our configuration), a final layer norm, and the language modeling head.

And we use it like this:

model = GPT2Tiny(
    vocab_size=vocab_size,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    dropout=dropout,
)
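To get a feel for how small this model is, we can estimate its parameter count from the layer shapes described so far. The sketch below is a back-of-the-envelope calculation, not a readout from the real model: it assumes vocab_size=1024, n_embd=256, block_size=64 and n_layer=4, a bias-free QKV projection (as in the attention code), biases everywhere else, and an untied, bias-free language modeling head. count_params is a hypothetical helper:

```python
def count_params(vocab_size, n_embd, block_size, n_layer, tie_lm_head=False):
    """Estimate parameter counts from the layer shapes (assumptions noted in comments)."""
    d = n_embd
    tok_emb = vocab_size * d                       # E: V x d
    pos_emb = block_size * d                       # P: T x d
    ln = 2 * d                                     # gamma and beta
    attn = 3 * d * d + (d * d + d)                 # qkv (no bias) + out_proj (with bias)
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # two linear layers with biases
    block = 2 * ln + attn + mlp
    lm_head = 0 if tie_lm_head else vocab_size * d  # assumed bias-free
    total = tok_emb + pos_emb + n_layer * block + ln + lm_head
    return {"tok_emb": tok_emb, "pos_emb": pos_emb, "block": block,
            "lm_head": lm_head, "total": total}

counts = count_params(vocab_size=1024, n_embd=256, block_size=64, n_layer=4)
for name, n in counts.items():
    print(f"{name:8s} {n:>10,d}")
```

Under these assumptions the model has a few million parameters, most of them inside the Transformer blocks, which is why it fits comfortably on a single T4.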

6. Training

LLMs are autoregressive, which means they predict the next token given the previous tokens. So, we train the model to minimize the negative log-likelihood of each next token:

\[L = -\sum_{t=1}^{T} \log P(y_t \mid x_1, \ldots, x_t)\]

In practice, for each batch, \(x\) contains token IDs for positions [1, ... , T] and \(y\) contains token IDs for positions [2, ... , T+1]. The model outputs logits for each position in \(x\), and we compare the logits at position t with the ground-truth token at position t+1, which is \(y_t\). This is implemented in the forward method of the GPT2Tiny class with F.cross_entropy, which by default averages the loss over all positions rather than summing it.

for step in range(train_steps):
    x, y = get_batch("train")
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Please note, the entire training process takes about 2 hours and 30 mins on my T4 GPU.
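The loss for a single position is small enough to write out by hand. The sketch below is a plain-Python version of cross entropy on raw logits (the function name is chosen for this example); it also gives a useful sanity check: a model that knows nothing and outputs uniform logits should start with a loss of about ln(V):

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

V = 1024  # vocabulary size used for the tokenizer

# Uniform logits: every token equally likely
uniform_loss = cross_entropy([0.0] * V, target=3)
print(uniform_loss, math.log(V))

# A confident, correct prediction gives a loss near zero
confident = [0.0] * V
confident[3] = 20.0
print(cross_entropy(confident, target=3))

# A confident, wrong prediction is punished heavily
print(cross_entropy(confident, target=7))
```

If the first training steps report a loss far above ln(1024) ≈ 6.93, something is likely wrong with the initialization or the data pipeline.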

6.1 Training Loop

The full training loop adds the following ingredients to the minimal loop above:

  1. AdamW: Adam with weight decay, standard for Transformer-style models.
  2. Learning rate schedule: a simple warmup (and possibly a decay) helps avoid blowing up early in training.
  3. Gradient clipping: keeps gradients from exploding, which is especially useful for deeper models or noisy batches.
  4. Train/val losses: tracking both lets you see if the model is overfitting or under-training.

Given the small dataset and tiny architecture, you should see the training loss fall steadily and the validation loss track it reasonably closely.
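Two of these ingredients, the learning rate schedule and gradient clipping, are simple enough to sketch in plain Python. The linear warmup followed by cosine decay is one common choice and is an assumption here, as are the helper names and all hyperparameter values; in PyTorch the clipping step is usually done with torch.nn.utils.clip_grad_norm_:

```python
import math

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads), norm
    scale = max_norm / norm
    return [g * scale for g in grads], norm

# Hypothetical settings, for illustration only
for step in (0, 4, 9, 10, 55, 100):
    print(step, lr_at(step, max_lr=1e-3, warmup_steps=10, total_steps=100))

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before, clipped)
```

Clipping changes only the length of the gradient vector, never its direction, so the update still points the same way but cannot take an oversized step.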


7. Text Generation

Once the model has learned to predict the next token, we can use it to generate text. We start with a prompt and ask the model for the probability distribution over the next token. We then sample from this distribution to get the next token, append it to the sequence, and repeat, generating one token at a time.
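The generation loop can be sketched independently of the real model. In the example below, toy_next_logits is a made-up stand-in for the model’s forward pass (it always predicts “previous token plus one”), and sample_next, generate and the temperature parameter are assumptions for illustration:

```python
import math
import random

def sample_next(logits, temperature, rng):
    """Turn logits into softmax weights (with temperature) and sample one token ID."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    weights = [math.exp(v - m) for v in scaled]  # unnormalized softmax probabilities
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

def generate(next_logits, prompt_ids, max_new_tokens, block_size, temperature=1.0, seed=0):
    """Autoregressive sampling: predict, sample, append, repeat."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        context = ids[-block_size:]    # the model only sees the last block_size tokens
        logits = next_logits(context)  # logits for the next token
        ids.append(sample_next(logits, temperature, rng))
    return ids

def toy_next_logits(context, vocab_size=5):
    """Stand-in for a trained model: puts all probability on (last token + 1) mod V."""
    nxt = (context[-1] + 1) % vocab_size
    return [0.0 if i == nxt else float("-inf") for i in range(vocab_size)]

print(generate(toy_next_logits, prompt_ids=[0], max_new_tokens=6, block_size=4))
```

With a real model, next_logits would run the network on the context and return the logits at the last position; lower temperatures make the samples more conservative, higher ones more varied.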

And we have finally built a tiny GPT-2 model that generates haikus! I prompted it with “round and round we go” and it gave the following output:

 round and round we go / of all it takes is time $
we are one family / knight errants of the divine / forever ever more $
stop falling in love / it is not good for the soul / loneliness is god $

References:

  1. Github Repository
  2. Colab Notebook


