<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="http://localhost:4000/feed.xml" rel="self" type="application/atom+xml" /><link href="http://localhost:4000/" rel="alternate" type="text/html" hreflang="en" /><updated>2025-12-01T16:52:17+06:00</updated><id>http://localhost:4000/feed.xml</id><title type="html">Murshed Al Amin</title><subtitle>An ML researcher focused on AI Safety and exploring PhD opportunities.
</subtitle><entry><title type="html">Understanding &amp;amp; Implementing a Tiny GPT-2 (With Limited GPU)</title><link href="http://localhost:4000/blog/2025/gpt2_tutorial/" rel="alternate" type="text/html" title="Understanding &amp;amp; Implementing a Tiny GPT-2 (With Limited GPU)" /><published>2025-11-27T00:00:00+06:00</published><updated>2025-11-27T00:00:00+06:00</updated><id>http://localhost:4000/blog/2025/gpt2_tutorial</id><content type="html" xml:base="http://localhost:4000/blog/2025/gpt2_tutorial/"><![CDATA[<h2 id="introduction">Introduction</h2>

<p>Everyone knows what GPT-2 is. We’ll try to understand it by implementing a tiny version of it. I say “tiny” because limited resources make it impossible to replicate the exact architecture: I only have a single T4 GPU on Colab and roughly 3–4 hours of runtime. So we’ll use a small dataset and a small model configuration. I guess we should call it tiny GPT-2.</p>

<p>The goal of this post is not to replicate GPT-2 at full scale. Instead, the goal is to <strong>deeply understand</strong>:</p>

<ul>
  <li>What each step in the pipeline does</li>
  <li>Why it is needed</li>
  <li>How it connects to the previous step</li>
  <li>What math happens inside it</li>
  <li>How the code implements that math</li>
</ul>

<p>If you want to skip the explanation and go straight to the implementation, you can find it on my <a href="https://github.com/t0n4r/gpt2-implementation">GitHub</a> and on <a href="https://colab.research.google.com/drive/1zmiC_-CCRgKn0fWWwn5x-_wJseZ4APCt?usp=sharing">Colab</a>.</p>

<p>I will try to explain every step thoroughly, and we will carry one <strong>running example</strong> throughout:</p>

<p><strong>“I think therefore I am.”</strong></p>

<hr />

<h1 id="notation-table">Notation Table</h1>

<table>
  <thead>
    <tr>
      <th>Symbol</th>
      <th>Meaning</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>V</td>
      <td>Vocabulary size (1024)</td>
    </tr>
    <tr>
      <td>d</td>
      <td>Embedding dimension (512)</td>
    </tr>
    <tr>
      <td>T</td>
      <td>Context window (64)</td>
    </tr>
    <tr>
      <td>H</td>
      <td>Attention heads (8)</td>
    </tr>
    <tr>
      <td>X</td>
      <td>Input embedding matrix</td>
    </tr>
    <tr>
      <td>E</td>
      <td>Token embedding matrix</td>
    </tr>
    <tr>
      <td>P</td>
      <td>Positional embedding matrix</td>
    </tr>
    <tr>
      <td>Q,K,V</td>
      <td>Query, Key, Value</td>
    </tr>
    <tr>
      <td>S</td>
      <td>Attention scores</td>
    </tr>
    <tr>
      <td>A</td>
      <td>Attention weights</td>
    </tr>
    <tr>
      <td>LN</td>
      <td>LayerNorm</td>
    </tr>
    <tr>
      <td>FFN</td>
      <td>Feedforward network</td>
    </tr>
    <tr>
      <td>CE</td>
      <td>Cross entropy loss</td>
    </tr>
    <tr>
      <td>logits</td>
      <td>Raw output before softmax</td>
    </tr>
  </tbody>
</table>

<hr />

<h1 id="1-tokenization">1. Tokenization</h1>

<h2 id="11-why-tokenization-exists">1.1 Why Tokenization Exists</h2>

<p>Transformers do not understand letters, words, or sentences; they only operate on numbers. Tokenization bridges that gap:</p>

<p><strong>Convert text into integers in a meaningful, consistent way.</strong></p>

<p>The emphasis is on “consistent”. If “cat” is <code class="language-plaintext highlighter-rouge">[20, 3, 99]</code> today and <code class="language-plaintext highlighter-rouge">[4, 77]</code> tomorrow, the model cannot learn anything. So tokenization must be <strong>deterministic</strong>, <strong>reversible</strong>, and <strong>stable</strong> across training and inference.</p>

<p>GPT-2 uses <strong>byte-level BPE (Byte Pair Encoding)</strong>, which is efficient and capable of handling all languages, including emojis.</p>

<hr />

<h2 id="12-byte-level-bpe-characters--bytes--merges">1.2 Byte-Level BPE: Characters → Bytes → Merges</h2>

<p>Byte-level BPE first turns text into raw bytes (0–255). It then looks for pairs of adjacent bytes that frequently appear together. It merges the most frequent such pair into a single, larger token and assigns that token its own integer ID. As this process repeats, the tokenizer gradually builds a compact vocabulary of subword tokens.</p>

<p>This approach avoids the problems of character-level tokenization (sequences too long) and word-level tokenization (too many unknown words). Because it starts from bytes, it can represent any text, and because it learns common subword patterns, it keeps sequences short. Frequent patterns like “th” or “ing” become single tokens, while rare patterns stay broken into smaller pieces (down to single bytes if necessary), so the tokenizer never produces an “unknown” token.</p>

<p>Encoding our running example as UTF-8 gives the following bytes (the final period is left out to keep the table short):</p>

<table>
  <thead>
    <tr>
      <th>Text</th>
      <th>I</th>
      <th>␣</th>
      <th>t</th>
      <th>h</th>
      <th>i</th>
      <th>n</th>
      <th>k</th>
      <th>␣</th>
      <th>t</th>
      <th>h</th>
      <th>e</th>
      <th>r</th>
      <th>e</th>
      <th>f</th>
      <th>o</th>
      <th>r</th>
      <th>e</th>
      <th>␣</th>
      <th>I</th>
      <th>␣</th>
      <th>a</th>
      <th>m</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Byte</td>
      <td>73</td>
      <td>32</td>
      <td>116</td>
      <td>104</td>
      <td>105</td>
      <td>110</td>
      <td>107</td>
      <td>32</td>
      <td>116</td>
      <td>104</td>
      <td>101</td>
      <td>114</td>
      <td>101</td>
      <td>102</td>
      <td>111</td>
      <td>114</td>
      <td>101</td>
      <td>32</td>
      <td>73</td>
      <td>32</td>
      <td>97</td>
      <td>109</td>
    </tr>
  </tbody>
</table>
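<p>The byte sequence for the running example can be computed with Python’s built-in <code class="language-plaintext highlighter-rouge">str.encode</code>. A minimal sketch; the emoji line is an extra illustration (not part of the running example) of why a byte-level vocabulary never meets an “unknown” character:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Encode the running example as UTF-8 and look at the raw bytes (0-255)
text = "I think therefore I am"
byte_ids = list(text.encode("utf-8"))
print(byte_ids)
# [73, 32, 116, 104, 105, 110, 107, 32, 116, 104, 101, 114, 101, 102, 111, 114, 101, 32, 73, 32, 97, 109]

# Any character, emoji included, becomes a few bytes and decodes back losslessly
emoji_ids = list("🙂".encode("utf-8"))
print(emoji_ids)                              # [240, 159, 153, 130]
print(bytes(byte_ids).decode("utf-8") == text)  # True
</code></pre></div></div>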

<p>This is the raw material for BPE.</p>

<p>Let the initial sequence be \(X = [b_1, b_2, \ldots, b_n]\), where \(b_i \in \{0, \ldots, 255\}\), and let \(V = \{0, 1, \ldots, 255\}\) be the initial vocabulary of bytes.</p>

<p>For any adjacent pair of bytes \(b_i\) and \(b_{i+1}\), we count the number of times they appear together in the sequence. This is denoted by \(C(b_i, b_{i+1})\).</p>

<p>At each merge step, we select the most frequent adjacent pair (of bytes, or of tokens created by earlier merges) and merge it into a new token \(t_{new}\):</p>

\[(u, v) = \arg\max_{(a,b)} C(a,b)\]

<p>We then replace every non-overlapping occurrence of \(uv\) in \(X\) with \(t_{new}\) and add \(t_{new}\) to \(V\). After merging, \(X\) typically gets shorter (it “compresses”), and we recompute the pair counts on the new sequence. We repeat this until the vocabulary reaches the target size (1024 in our setup).</p>

<p>Intuitively, high-frequency neighboring bytes (or already-merged tokens) become single tokens. Over many merges, common character bigrams, trigrams, and word pieces are “absorbed” into compact subword tokens. This improves efficiency downstream: sequences become shorter, and the model can learn reusable subword patterns.</p>

<hr />

<h2 id="13-counting-adjacent-pairs">1.3 Counting Adjacent Pairs</h2>
<p>BPE is a greedy algorithm: at each step, it merges only the most frequent adjacent pair. Pair statistics must be recomputed after every merge, because each merge changes the sequence and therefore the counts. Choosing the most frequent pair prioritizes the merge that gives the largest immediate compression and the most reusable subword unit.</p>

<p>Example counts:</p>

<table>
  <thead>
    <tr>
      <th>Pair</th>
      <th>Bytes</th>
      <th>Count</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>“I␣”</td>
      <td>(73, 32)</td>
      <td>2</td>
    </tr>
    <tr>
      <td>“␣t”</td>
      <td>(32, 116)</td>
      <td>2</td>
    </tr>
    <tr>
      <td>“th”</td>
      <td>(116, 104)</td>
      <td>2</td>
    </tr>
    <tr>
      <td>“re”</td>
      <td>(114, 101)</td>
      <td>2</td>
    </tr>
  </tbody>
</table>
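<p>These counts can be checked with <code class="language-plaintext highlighter-rouge">collections.Counter</code> from Python’s standard library. A minimal sketch that counts pairs in the single running-example sentence (a real tokenizer counts over the whole corpus):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from collections import Counter

# Byte sequence of the running example
ids = list("I think therefore I am".encode("utf-8"))

# Count every adjacent pair (b_i, b_{i+1})
pair_counts = Counter(zip(ids, ids[1:]))

# Keep only the pairs that occur more than once
frequent = {pair: n for pair, n in pair_counts.items() if n > 1}
print(frequent)  # {(73, 32): 2, (32, 116): 2, (116, 104): 2, (114, 101): 2}
</code></pre></div></div>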

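<p>The most frequent pair is merged into a <strong>new token</strong>. On this single sentence, all four pairs are tied at a count of 2; if ties go to the pair seen first, the first merge is <code class="language-plaintext highlighter-rouge">(73, 32) → 256</code> (“I␣”), and the 22-byte sequence becomes 20 tokens:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[256, 116, 104, 105, 110, 107, 32, 116, 104, 101, 114, 101, 102, 111, 114, 101, 32, 256, 97, 109]
</code></pre></div></div>

<p>The counts are then recomputed on this shorter sequence. (32, 116) now occurs only once, because the space after the first “I” was absorbed into token 256, so the next merges would be “th” <code class="language-plaintext highlighter-rouge">(116, 104) → 257</code> and then “re” <code class="language-plaintext highlighter-rouge">(114, 101) → 258</code>. Each merge adds one token to the vocabulary, and this repeats until the vocabulary reaches 1024 tokens.</p>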
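<p>The whole loop (count pairs, take the argmax, merge, repeat) fits in a short pure-Python sketch. It is a toy written for illustration, not the Hugging Face implementation used in Section 1.5: the names <code class="language-plaintext highlighter-rouge">merge</code> and <code class="language-plaintext highlighter-rouge">train_bpe</code> are hypothetical, ties go to the pair seen first, and it works on one flat sequence, whereas real byte-level tokenizers first split text into word-like chunks so that merges do not cross those boundaries.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from collections import Counter


def merge(seq, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` in `seq` with `new_id`."""
    out = []
    skip = False  # True when the current element was consumed by the previous merge
    for a, b in zip(seq, seq[1:] + [None]):
        if skip:
            skip = False
            continue
        if (a, b) == pair:
            out.append(new_id)
            skip = True  # b belongs to this merge, so don't emit it again
        else:
            out.append(a)
    return out


def train_bpe(seq, vocab_size):
    """Greedy BPE: merge the most frequent adjacent pair until vocab_size tokens exist."""
    merges = {}  # (left, right) -> new token ID, in the order the merges were learned
    for new_id in range(256, vocab_size):  # IDs 0-255 are the raw bytes
        counts = Counter(zip(seq, seq[1:]))
        if not counts:
            break  # only one token left; nothing to merge
        pair = max(counts, key=counts.get)  # ties go to the pair seen first
        seq = merge(seq, pair, new_id)
        merges[pair] = new_id
    return seq, merges


ids = list("I think therefore I am".encode("utf-8"))
compressed, merges = train_bpe(ids, vocab_size=259)
print(merges)                           # {(73, 32): 256, (116, 104): 257, (114, 101): 258}
print(len(ids), "->", len(compressed))  # 22 -> 16
</code></pre></div></div>

<p>With <code class="language-plaintext highlighter-rouge">vocab_size=259</code> only three merges are learned; on this single sentence they shrink the sequence from 22 to 16 tokens.</p>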

<hr />

<h2 id="14-why-byte-level-bpe-works">1.4 Why Byte-Level BPE Works</h2>

<p>Why not just tokenize by words? Or why not just keep characters?</p>

<p>Byte-level BPE answers both concerns. Characters make sequences too long; words produce out-of-vocabulary problems and brittle vocabularies. Byte-level BPE starts from the universal building blocks (bytes) and, almost like a child learning to spell, gradually learns useful subwords from them. This guarantees reversibility and zero unknown tokens while keeping sequences reasonably short.</p>

<hr />

<h2 id="15-tokenizer-code">1.5 Tokenizer Code</h2>

<p>We use Hugging Face’s ByteLevelBPETokenizer, which implements the algorithm in optimized Rust. We don’t have to write the merge loop ourselves; instead, we point it to our corpus and ask it to learn a vocabulary of a given size. Internally, it performs the pair-counting, the argmax selection, the creation of new token IDs and the non-overlapping replacements repeatedly until it reaches the requested vocabulary size.</p>

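<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">tokenizers</span> <span class="kn">import</span> <span class="n">ByteLevelBPETokenizer</span>
<span class="kn">from</span> <span class="nn">tokenizers.pre_tokenizers</span> <span class="kn">import</span> <span class="n">ByteLevel</span>
<span class="kn">from</span> <span class="nn">tokenizers.decoders</span> <span class="kn">import</span> <span class="n">ByteLevel</span> <span class="k">as</span> <span class="n">ByteLevelDecoder</span>

<span class="c1"># file_path is the path to the plain-text training corpus
</span><span class="k">with</span> <span class="nf">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="sh">"</span><span class="s">utf-8</span><span class="sh">"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">text</span> <span class="o">=</span> <span class="n">f</span><span class="p">.</span><span class="nf">read</span><span class="p">()</span>

<span class="c1"># Learn a 1024-token vocabulary (byte alphabet + learned merges + special token)
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="nc">ByteLevelBPETokenizer</span><span class="p">()</span>
<span class="n">tokenizer</span><span class="p">.</span><span class="n">pre_tokenizer</span> <span class="o">=</span> <span class="nc">ByteLevel</span><span class="p">()</span>
<span class="n">tokenizer</span><span class="p">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="nc">ByteLevelDecoder</span><span class="p">()</span>
<span class="n">tokenizer</span><span class="p">.</span><span class="nf">train</span><span class="p">(</span>
    <span class="n">files</span><span class="o">=</span><span class="p">[</span><span class="n">file_path</span><span class="p">],</span>
    <span class="n">vocab_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
    <span class="n">min_frequency</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>  <span class="c1"># only merge pairs seen at least twice
</span>    <span class="n">special_tokens</span><span class="o">=</span><span class="p">[</span><span class="sh">"</span><span class="s">&lt;|endoftext|&gt;</span><span class="sh">"</span><span class="p">],</span>
<span class="p">)</span>

<span class="c1"># Encode the whole corpus into token IDs
</span><span class="n">all_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="n">text</span><span class="p">).</span><span class="n">ids</span>
<span class="n">all_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">array</span><span class="p">(</span><span class="n">all_ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
</code></pre></div></div>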
<hr />

<h1 id="2-preparing-training-data">2. Preparing Training Data</h1>
<p>After converting the entire corpus into a long sequence of token IDs in Section 1, we now need to transform that stream into training examples the model can learn from. This involves three steps:</p>

<ol>
  <li><strong>Train/Val Split</strong>: We split the sequence into a training set and a validation set, typically using a 90/10 ratio.</li>
  <li><strong>Sampling Windows</strong>: We define fixed-size windows (64 tokens in our setup) to create training examples.</li>
  <li><strong>get_batch()</strong>: We implement a function to sample random windows from the data.</li>
</ol>

<h2 id="21-trainval-split">2.1 Train/Val Split</h2>
<p>We divide the long token sequence into two parts: 90% for training and 10% for validation. The model learns from the training portion, and the validation portion tells us how well the model generalizes to unseen text.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_tokens</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">all_ids</span><span class="p">)</span>
<span class="n">split_idx</span> <span class="o">=</span> <span class="nf">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">n_tokens</span><span class="p">)</span>
<span class="n">train_tokens</span> <span class="o">=</span> <span class="n">all_ids</span><span class="p">[:</span><span class="n">split_idx</span><span class="p">]</span>
<span class="n">val_tokens</span>   <span class="o">=</span> <span class="n">all_ids</span><span class="p">[</span><span class="n">split_idx</span><span class="p">:]</span>
</code></pre></div></div>

<hr />

<h2 id="22-sampling-windows">2.2 Sampling Windows</h2>
<p>We choose a context window length (<code class="language-plaintext highlighter-rouge">block_size</code>) that tells the model how many tokens it can “see” at once. We train on fixed-length windows, so every training example has the same number of tokens. We keep our block size small because we are training on a haiku dataset, whose poems are short.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">block_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">256</span>
</code></pre></div></div>

<hr />

<h2 id="23-get_batch">2.3 get_batch()</h2>
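<p>We generate a batch of random windows from the dataset. Each window becomes an input sequence <code class="language-plaintext highlighter-rouge">x</code>, and the same window shifted one token to the right becomes the target sequence <code class="language-plaintext highlighter-rouge">y</code>.</p>

<p>Random sampling prevents the model from seeing the text only in its original order (e.g., one haiku after another). Instead, each batch mixes many unrelated contexts from across the corpus, which makes batches less correlated and tends to improve generalization and training stability.</p>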

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_batch</span><span class="p">(</span><span class="n">split</span><span class="p">):</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">train_tokens</span> <span class="k">if</span> <span class="n">split</span> <span class="o">==</span> <span class="sh">"</span><span class="s">train</span><span class="sh">"</span> <span class="k">else</span> <span class="n">val_tokens</span>
    <span class="n">max_start</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="o">-</span> <span class="n">block_size</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">data</span><span class="p">[</span><span class="n">i</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">])</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">data</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">block_size</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">idx</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span>
</code></pre></div></div>

<hr />

<h2 id="24-math-behind-the-training-windows">2.4 Math behind the Training Windows</h2>
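<p>Each training window of length \(T\) (<code class="language-plaintext highlighter-rouge">block_size</code>, 64 in our setup) creates \(T\) prediction tasks.
For input tokens:</p>

\[x = (x_1, x_2, \ldots, x_T)\]

<p>the model tries to predict the next token at every position:</p>

\[y_1 = x_2,\; y_2 = x_3,\; \ldots,\; y_T = x_{T+1}\]

<p>where \(x_{T+1}\) is the token just after the window (which is why <code class="language-plaintext highlighter-rouge">get_batch()</code> slices <code class="language-plaintext highlighter-rouge">y</code> one position later than <code class="language-plaintext highlighter-rouge">x</code>). This is exactly the autoregressive language modeling objective:</p>

\[P(y_t \mid x_1, \ldots, x_t) = P(x_{t+1} \mid x_1, \ldots, x_t)\]

<p>By training on this one-token shift across many randomly sampled windows, the model learns the grammar, style, structure and statistical regularities of the text.</p>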
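<p>To make the one-token shift concrete, here is a small NumPy sketch with a made-up token stream and a window of 4 (both hypothetical, chosen for readability). It also computes a common sanity check: a model that assigns equal probability to all \(V = 1024\) tokens has a cross-entropy loss of \(\ln 1024 \approx 6.93\), so a freshly initialized model should start training at roughly that loss.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Hypothetical token IDs standing in for a slice of train_tokens
data = np.array([10, 20, 30, 40, 50, 60])
block_size = 4  # small window for display; our model uses 64
i = 0           # start position of the window

x = data[i : i + block_size]          # inputs  x_1 .. x_T
y = data[i + 1 : i + 1 + block_size]  # targets y_t = x_{t+1}

# One window yields block_size next-token prediction tasks
for t in range(block_size):
    print(f"context {x[: t + 1].tolist()} -> target {y[t]}")

# Baseline: uniform predictions over V tokens give a cross-entropy of ln(V)
V = 1024
uniform_probs = np.full(V, 1.0 / V)
loss = -np.log(uniform_probs[y]).mean()
print(round(float(loss), 3))  # 6.931
</code></pre></div></div>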

<hr />

<h1 id="3-embeddings">3. Embeddings</h1>

<p>Up to this point, we have converted our raw text into a long sequence of token IDs. These IDs are integers (like 391, 82, 256, etc.) that come from the Byte-Level BPE tokenizer. Although these integers are essential for compactly representing text, they do not carry any semantic information. The number 256 does not mean anything by itself. It just identifies a subword unit (a particular BPE merge) in the vocabulary.</p>

<p>A Transformer cannot operate directly on these integer IDs. Instead, each token ID is mapped to a vector of length \(d = 256\).</p>

<h2 id="31-why-embeddings-exist">3.1 Why Embeddings Exist</h2>

<p>When a token ID like 73 (the byte for “I”) enters the model, nothing distinguishes it from 116 (“t”) or 32 (space). If we fed these raw integers into a neural network, the model would learn nothing meaningful because numerical closeness has no semantic meaning. For example, 32 (space) is not semantically “closer” to 38 than to 200, but a raw neural network would be forced to treat it as such.</p>

<p>Embeddings solve this issue by mapping each integer token ID to a learned vector in a continuous space. This allows the model to develop a representation of what tokens mean, how they relate to one another, and how they function in context.</p>

<p>Every token receives a vector of length \(d = 256\) in our implementation. Over training, tokens that behave similarly in text (like “the”, “a”, “an”) tend to end up with vectors that lie relatively close together in this space.</p>

<hr />

<h2 id="32-token-embeddings">3.2 Token Embeddings</h2>

<p>We define the token embedding layer as:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">self</span><span class="p">.</span><span class="n">tok_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Embedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">)</span>
</code></pre></div></div>

<p><code class="language-plaintext highlighter-rouge">nn.Embedding(vocab_size, n_embd)</code> creates a lookup table that maps each token ID to its corresponding embedding vector:</p>

\[E \in \mathbb{R}^{V \times d}\]

<p>Where:</p>

<ul>
  <li>$V$ is the vocabulary size (number of unique tokens; 1024 in our case)</li>
  <li>$d$ is the embedding dimension (256 in our implementation)</li>
</ul>

<p>Let’s say the tokenizer gave the following IDs (illustrative values; with \(V = 1024\), every ID lies between 0 and 1023):</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>“I” → 73
“think” (after BPE merges) → 502
“therefore” (after BPE merges) → 611
“I” → 73
“am” (after BPE merges) → 870
</code></pre></div></div>
<p>The token embedding layer converts this integer sequence into 256-dimensional vectors by looking up row \(E[\text{id}]\) for each ID. Each row is a learned dense vector, such as:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>E[73]  = [0.12, -1.04, 0.55, ..., 0.73]
E[502] = [-0.22, 0.15, 1.98, ..., -0.34]
E[611] = [0.01, 0.99, -0.01, ..., 0.02]
E[73]  = [0.12, -1.04, 0.55, ..., 0.73]
E[870] = [0.47, -0.33, 0.81, ..., -0.65]
</code></pre></div></div>
<p>These numbers are randomly initialized at the start of training and gradually become meaningful representations as training proceeds. Note that both occurrences of “I” look up the same row, <code class="language-plaintext highlighter-rouge">E[73]</code>.</p>

<h2 id="33-positional-embeddings">3.3 Positional Embeddings</h2>
<p>Even if we embed each token into a meaningful vector, the model cannot distinguish <code class="language-plaintext highlighter-rouge">I think therefore I am</code> from <code class="language-plaintext highlighter-rouge">am I therefore think I</code>. Both sequences contain the same token embeddings, just in a different order.</p>

<p>Self-attention processes all positions in parallel and has no built-in notion of sequence order. Unlike RNNs, where order is built into the recurrence, Transformers need order to be injected explicitly. This is where positional embeddings come in, which we define by:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">self</span><span class="p">.</span><span class="n">pos_emb</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Embedding</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">)</span>
</code></pre></div></div>

<p>This creates a learned matrix:</p>

\[P \in \mathbb{R}^{T \times d}\]

<p>Where:</p>

<ul>
  <li>$T$ is the block size or the maximum sequence length (64 in our implementation)</li>
  <li>$d$ is the embedding dimension (256 in our implementation)</li>
</ul>

<p>and each row <code class="language-plaintext highlighter-rouge">P[t]</code> is a learned vector that encodes position \(t\) in the sequence.</p>

<p>The input to the Transformer becomes:</p>

\[X[t] = E[\text{token}_t] + P[t]\]

<p>which fuses the token identity and position information into a single vector.</p>

<p>For our sequence <code class="language-plaintext highlighter-rouge">I think therefore I am</code>, the positional embeddings are the first five rows of \(P\):</p>

\[P[0],\; P[1],\; P[2],\; P[3],\; P[4]\]

<p>And the final embedding would look like this:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X[0] = E[73] + P[0]
X[1] = E[502] + P[1]
X[2] = E[611] + P[2]
X[3] = E[73] + P[3]
X[4] = E[870] + P[4]
</code></pre></div></div>

<p>Notice that even though the token “I” appears twice, its two occurrences receive different positional embeddings (<code class="language-plaintext highlighter-rouge">P[0]</code> and <code class="language-plaintext highlighter-rouge">P[3]</code>), so their input vectors <code class="language-plaintext highlighter-rouge">X[0]</code> and <code class="language-plaintext highlighter-rouge">X[3]</code> differ. This lets the model learn order-dependent patterns such as word order and sentence structure.</p>
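<p>Both lookups are plain row indexing, so the embedding step can be sketched in NumPy. This is an illustration under stated assumptions: the token IDs are hypothetical, the two tables are random (as at initialization), and \(d\) is shrunk to 8 so the output stays small; in the model these tables are the <code class="language-plaintext highlighter-rouge">nn.Embedding</code> weights.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

rng = np.random.default_rng(0)
V, T, d = 1024, 64, 8  # d = 8 only to keep the printout small

# Lookup tables, randomly initialised as at the start of training
E = rng.normal(size=(V, d))  # token embeddings: one row per token ID
P = rng.normal(size=(T, d))  # positional embeddings: one row per position

# Hypothetical token IDs for "I think therefore I am"
ids = np.array([73, 502, 611, 73, 870])
positions = np.arange(len(ids))  # 0, 1, 2, 3, 4

X = E[ids] + P[positions]  # X[t] = E[token_t] + P[t]

print(X.shape)                            # (5, 8)
print(np.allclose(E[ids[0]], E[ids[3]]))  # True: both "I"s share one token embedding
print(np.allclose(X[0], X[3]))            # False: their positions differ
</code></pre></div></div>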

<hr />

<h1 id="4-transformer-block">4. Transformer Block</h1>

<p>A GPT model is made of a stack of identical Transformer blocks. In our implementation, we stack 4 blocks, and each block performs the same sequence of operations (with its own parameters).</p>

<p>A single Transformer block takes in a sequence of embeddings:</p>

\[X \in \mathbb{R}^{B \times T \times d}\]

<p>Where:</p>

<ul>
  <li>$B$ is the batch size (256 in our implementation)</li>
  <li>$T$ is the block size or the maximum sequence length (64 in our implementation)</li>
  <li>$d$ is the embedding dimension (256 in our implementation)</li>
</ul>

<p>To understand the block, we will break it down into its components:</p>
<ol>
  <li>(Pre) LayerNorm</li>
  <li>Self-Attention</li>
  <li>Feedforward Network (MLP)</li>
  <li>Residual Connections</li>
</ol>

<hr />

<h1 id="41-pre-layernorm">4.1 (Pre) LayerNorm</h1>

<p>We use the pre-norm Transformer design, where LayerNorm is applied before the attention sublayer and before the feedforward MLP, rather than after them as in the original Transformer.</p>

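<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>

<span class="c1"># CausalSelfAttention and FeedForward are the attention and MLP modules described in the following sections
</span><span class="k">class</span> <span class="nc">TransformerBlock</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>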
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">n_head</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">ln1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">LayerNorm</span><span class="p">(</span><span class="n">n_embd</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">attn</span> <span class="o">=</span> <span class="nc">CausalSelfAttention</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">n_head</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">ln2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">LayerNorm</span><span class="p">(</span><span class="n">n_embd</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">attn</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>  <span class="c1"># Pre-norm + residual
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">mlp</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>   <span class="c1"># Pre-norm + residual
</span>        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<p>LayerNorm operates per token, across its feature dimensions.</p>

<p>For each token vector \(x_t \in \mathbb{R}^{d}\):</p>

<ol>
  <li>We compute the mean:</li>
</ol>

\[\mu_t = \frac{1}{d} \sum_{i=1}^{d} x_{t,i}\]

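<ol start="2">
  <li>We compute the variance:</li>
</ol>

\[\sigma_t^2 = \frac{1}{d} \sum_{i=1}^{d} (x_{t,i} - \mu_t)^2\]

<ol start="3">
  <li>We normalize the token vector, where \(\epsilon\) is a small constant (\(10^{-5}\) by default in PyTorch) that prevents division by zero:</li>
</ol>

\[\hat{x}_t = \frac{x_t - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}}\]

<ol start="4">
  <li>We scale and shift each feature \(i\) with learnable parameters \(\gamma_i\) and \(\beta_i\):</li>
</ol>

\[LN(x_t)_i = \gamma_i \hat{x}_{t,i} + \beta_i\]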

<p>Let’s assume the model produced this pre-attention embedding vector for “therefore” (using \(d = 4\) to keep the arithmetic small):</p>

\[x_t = [-1.7, 0.2, 0.8, -0.5]\]

<p>Its mean and variance are:</p>

\[\mu_t = \frac{-1.7 + 0.2 + 0.8 - 0.5}{4} = \frac{-1.2}{4} = -0.3\]

\[\sigma_t^2 = \frac{(-1.4)^2 + 0.5^2 + 1.1^2 + (-0.2)^2}{4} = \frac{3.46}{4} = 0.865\]

<p>Then the normalized vector is:</p>

\[\hat{x}_t = \frac{x_t - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}} = \frac{[-1.7, 0.2, 0.8, -0.5] - [-0.3, -0.3, -0.3, -0.3]}{\sqrt{0.865 + \epsilon}} = \frac{[-1.4, 0.5, 1.1, -0.2]}{\sqrt{0.865 + \epsilon}}\]

\[\hat{x}_t \approx \frac{[-1.4, 0.5, 1.1, -0.2]}{0.930} \approx [-1.51, 0.54, 1.18, -0.22]\]

<p>The result has mean 0 and variance (almost exactly) 1. The final LayerNorm output is:</p>

\[LN(x_t)_i = \gamma_i \hat{x}_{t,i} + \beta_i\]

<p>Where \(\gamma_i\) and \(\beta_i\) are learnable parameters.</p>

<p>PyTorch’s <code class="language-plaintext highlighter-rouge">nn.LayerNorm</code> initializes them to \(\gamma_i = 1\) and \(\beta_i = 0\), so at the start of training the output is simply:</p>

\[LN(x_t) = \hat{x}_t \approx [-1.51, 0.54, 1.18, -0.22]\]
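<p>The four steps fit in a few lines of NumPy. The sketch below (the helper name <code class="language-plaintext highlighter-rouge">layer_norm</code> is hypothetical) computes the statistics and the normalized vector for the “therefore” example; it follows the same formula as <code class="language-plaintext highlighter-rouge">nn.LayerNorm</code> over the last dimension, with the same default \(\epsilon = 10^{-5}\), but it is not PyTorch’s implementation.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last (feature) axis, then scale by gamma and shift by beta."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance: divides by d, as in the formula
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta


x_t = np.array([-1.7, 0.2, 0.8, -0.5])  # "therefore" from the example
d = x_t.shape[-1]

# gamma = 1 and beta = 0 match nn.LayerNorm's initial values
out = layer_norm(x_t, gamma=np.ones(d), beta=np.zeros(d))

print(round(float(x_t.mean()), 3), round(float(x_t.var()), 3))  # -0.3 0.865
print(out.round(2))  # approx. [-1.51, 0.54, 1.18, -0.22]
</code></pre></div></div>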

<hr />

<h1 id="42-self-attention">4.2 Self-Attention</h1>

<p>Self-attention is the most important part of GPT-style models. It allows each token to “look at” itself and the tokens before it in the sequence, decide which of them are relevant, and decide how much attention to pay to each.</p>

<p>When we read the sentence <code class="language-plaintext highlighter-rouge">I think therefore I am.</code>, we interpret the word <code class="language-plaintext highlighter-rouge">therefore</code> in light of the phrase <code class="language-plaintext highlighter-rouge">I think</code> that precedes it. A GPT-style model is causal: while processing <code class="language-plaintext highlighter-rouge">therefore</code>, it may attend only to <code class="language-plaintext highlighter-rouge">therefore</code> itself and the tokens before it, never to the following <code class="language-plaintext highlighter-rouge">I am</code>. Attention lets the model learn these relationships automatically, using <code class="language-plaintext highlighter-rouge">Query</code>, <code class="language-plaintext highlighter-rouge">Key</code> and <code class="language-plaintext highlighter-rouge">Value</code> vectors computed from each token.</p>

<hr />

<h2 id="step-1--q-k-v">Step 1 — Q, K, V</h2>

<p>Let’s walk through an example with the sentence “I think therefore I am”. In the real model, each token reaches this point as a \(d\)-dimensional vector (after token and positional embedding and LayerNorm); to keep the arithmetic readable, we use made-up 4-dimensional vectors:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X = [
  [1.0, 0.5, 0.2, 0.8],        ← "I"
  [-0.3, 1.9, -0.6, 0.4],      ← "think"
  [-1.7, 0.2, 0.8, -0.5],      ← "therefore"
  [1.0, 0.5, 0.2, 0.8],        ← "I"
  [-1.34, -0.45, 0.45, 1.34]   ← "am"
]
</code></pre></div></div>

<p>We create Q, K and V by multiplying X by three different weight matrices, which gives three different representations of each token:</p>

\[Q = X W_Q \;\text{(queries)}, \qquad K = X W_K \;\text{(keys)}, \qquad V = X W_V \;\text{(values)}\]

<p>These weight matrices are learned during training. To keep the example easy to follow, assume all three are the 4×4 identity matrix:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>W_Q = W_K = W_V = [
  [1, 0, 0, 0],
  [0, 1, 0, 0],
  [0, 0, 1, 0],
  [0, 0, 0, 1]
]
</code></pre></div></div>

<p>With identity weights, Q = K = V = X. In a trained model the three matrices differ, so each token gets a distinct query, key and value.</p>

<p>In the code, the three projections are fused into a single linear layer whose output holds the Q, K and V projections side by side:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># One layer computes all three projections; its output (width 3 * n_embd) is later split into Q, K, V
</span><span class="n">self</span><span class="p">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>
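<p>To see that one fused layer is equivalent to three separate projections, here is a NumPy sketch with random stand-in weights (not trained values). One detail differs from PyTorch: <code class="language-plaintext highlighter-rouge">nn.Linear</code> stores its weight as an <code class="language-plaintext highlighter-rouge">(out_features, in_features)</code> matrix and computes \(XW^\top\), whereas the sketch stores \(W\) as <code class="language-plaintext highlighter-rouge">(d, 3d)</code> and computes \(XW\) directly; the splitting logic is the same. In PyTorch, the split is commonly written as <code class="language-plaintext highlighter-rouge">q, k, v = self.qkv(x).split(n_embd, dim=2)</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Running-example vectors (5 tokens, d = 4)
X = np.array([
    [1.0, 0.5, 0.2, 0.8],
    [-0.3, 1.9, -0.6, 0.4],
    [-1.7, 0.2, 0.8, -0.5],
    [1.0, 0.5, 0.2, 0.8],
    [-1.34, -0.45, 0.45, 1.34],
])
n, d = X.shape

# Three separate projections (random stand-ins for learned weights)
rng = np.random.default_rng(0)
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

# Fused projection: one (d, 3d) matrix, analogous to nn.Linear(n_embd, 3 * n_embd)
W_qkv = np.concatenate([W_Q, W_K, W_V], axis=1)
Q, K, V = np.split(X @ W_qkv, 3, axis=1)  # each has shape (n, d)

print(Q.shape)  # (5, 4)
print(np.allclose(Q, X @ W_Q), np.allclose(K, X @ W_K), np.allclose(V, X @ W_V))  # True True True
</code></pre></div></div>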

<hr />

<h2 id="step-2--computing-attention-scores">Step 2 — Computing Attention Scores</h2>
<p>Now each token needs to answer: “How relevant is every other token to me?” We measure this with the dot product between the query of token i and the key of token j:</p>

\[score(i, j) = Q_i \cdot K_j\]

<p>We compute this dot product for every query-key pair.</p>

<p>Token 0 (“I”) queries all keys:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Q[0] · K[0]:
[1.0, 0.5, 0.2, 0.8] · [1.0, 0.5, 0.2, 0.8]
= 1.0×1.0 + 0.5×0.5 + 0.2×0.2 + 0.8×0.8
= 1.0 + 0.25 + 0.04 + 0.64
= 1.93

Q[0] · K[1]:
[1.0, 0.5, 0.2, 0.8] · [-0.3, 1.9, -0.6, 0.4]
= 1.0×(-0.3) + 0.5×1.9 + 0.2×(-0.6) + 0.8×0.4
= -0.3 + 0.95 - 0.12 + 0.32
= 0.85

Q[0] · K[2]:
[1.0, 0.5, 0.2, 0.8] · [-1.7, 0.2, 0.8, -0.5]
= 1.0×(-1.7) + 0.5×0.2 + 0.2×0.8 + 0.8×(-0.5)
= -1.7 + 0.1 + 0.16 - 0.4
= -1.84

Q[0] · K[3]:
[1.0, 0.5, 0.2, 0.8] · [1.0, 0.5, 0.2, 0.8]
= 1.93 (same as Q[0]·K[0])

Q[0] · K[4]:
[1.0, 0.5, 0.2, 0.8] · [-1.34, -0.45, 0.45, 1.34]
= 1.0×(-1.34) + 0.5×(-0.45) + 0.2×0.45 + 0.8×1.34
= -1.34 - 0.225 + 0.09 + 1.072
= -0.403

Row 0 scores: [1.93, 0.85, -1.84, 1.93, -0.403]

Token 1 ("think") queries all keys the same way:

Q[1] · K[0]:
[-0.3, 1.9, -0.6, 0.4] · [1.0, 0.5, 0.2, 0.8]
= -0.3×1.0 + 1.9×0.5 + (-0.6)×0.2 + 0.4×0.8
= -0.3 + 0.95 - 0.12 + 0.32
= 0.85

Q[1] · K[1]:
[-0.3, 1.9, -0.6, 0.4] · [-0.3, 1.9, -0.6, 0.4]
= (-0.3)×(-0.3) + 1.9×1.9 + (-0.6)×(-0.6) + 0.4×0.4
= 0.09 + 3.61 + 0.36 + 0.16
= 4.22

Q[1] · K[2]:
[-0.3, 1.9, -0.6, 0.4] · [-1.7, 0.2, 0.8, -0.5]
= (-0.3)×(-1.7) + 1.9×0.2 + (-0.6)×0.8 + 0.4×(-0.5)
= 0.51 + 0.38 - 0.48 - 0.2
= 0.21

Q[1] · K[3]:
[-0.3, 1.9, -0.6, 0.4] · [1.0, 0.5, 0.2, 0.8]
= 0.85 (same as Q[1]·K[0])

Q[1] · K[4]:
[-0.3, 1.9, -0.6, 0.4] · [-1.34, -0.45, 0.45, 1.34]
= (-0.3)×(-1.34) + 1.9×(-0.45) + (-0.6)×0.45 + 0.4×1.34
= 0.402 - 0.855 - 0.27 + 0.536
= -0.187
</code></pre></div></div>

<p>Row 1 scores: <code class="language-plaintext highlighter-rouge">[0.85, 4.22, 0.21, 0.85, -0.187]</code></p>

<p>Rows 2, 3 and 4 are computed the same way.</p>

<p>Complete Score Matrix <code class="language-plaintext highlighter-rouge">(Q K^T)</code>:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Score Matrix S = Q K^T = [
  [1.93,   0.85,  -1.84,   1.93,  -0.403],  ← Token 0
  [0.85,   4.22,   0.21,   0.85,  -0.187],  ← Token 1
  [-1.84,  0.21,   3.82,  -1.84,   1.878],  ← Token 2
  [1.93,   0.85,  -1.84,   1.93,  -0.403],  ← Token 3
  [-0.403, -0.187, 1.878, -0.403,  3.9962]  ← Token 4
]

Shape: [5, 5]
</code></pre></div></div>

<p>Each row represents one token’s attention scores to all tokens.</p>
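<p>The whole score matrix is a single matrix product, \(S = Q K^T\). The NumPy sketch below reproduces the table above, again assuming identity weights so that Q = K = X.</p>

```python
import numpy as np

X = np.array([
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-0.3, 1.9, -0.6, 0.4],      # "think"
    [-1.7, 0.2, 0.8, -0.5],      # "therefore"
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-1.34, -0.45, 0.45, 1.34],  # "am"
])
Q = K = X  # identity weights

# score[i, j] = Q[i] · K[j] for every pair at once
S = Q @ K.T
print(np.round(S, 3))
```

<p>The matrix is symmetric here only because Q = K in this toy setup. With learned, different W_Q and W_K, token i's interest in token j generally differs from token j's interest in token i.</p>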

<hr />

<h2 id="step-3--scale-by-d">Step 3 — Scale by √d</h2>

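<p>Next, we divide the scores by \(\sqrt{d}\). Without this, dot products grow with the dimension \(d\), which pushes softmax toward near one-hot outputs with very small gradients. This walkthrough uses a single head; with \(H\) heads, each head works in dimension \(d / H\), so we divide by \(\sqrt{d / H}\) instead.</p>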

<p>Formula: \(S_{scaled}[i,j] = \frac{S[i,j]}{\sqrt{d}}\)</p>

<p>Where d = 4, so √d = 2.</p>

<p>Divide all scores by \(\sqrt{d}\):</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Scaled Scores = S / √4 = [
  [0.965,  0.425, -0.92,   0.965,  -0.2015],
  [0.425,  2.11,   0.105,  0.425,  -0.0935],
  [-0.92,  0.105,  1.91,  -0.92,    0.939],
  [0.965,  0.425, -0.92,   0.965,  -0.2015],
  [-0.2015, -0.0935, 0.939, -0.2015, 1.9981]
]
</code></pre></div></div>
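<p>Why does the dot product grow with \(d\)? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of \(d\) terms of variance 1, so its variance is about \(d\). The quick NumPy check below uses random Gaussian vectors, which is an assumption for illustration (real activations are not exactly Gaussian).</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 256, 10_000
q = rng.standard_normal((n, d))  # assumption: i.i.d. entries, mean 0, variance 1
k = rng.standard_normal((n, d))

dots = (q * k).sum(axis=1)       # n independent dot products q · k
scaled = dots / np.sqrt(d)       # the attention scaling

# Variance is about d before scaling and about 1 after
print(round(dots.var(), 1), round(scaled.var(), 3))
```

<p>Dividing by \(\sqrt{d}\) keeps the scores at roughly unit scale no matter how wide the model is, so softmax stays in a range where its gradients are useful.</p>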

<hr />

<h2 id="step-4--apply-causal-mask-for-autoregressive-models-like-gpt">Step 4 — Apply Causal Mask (For Autoregressive Models like GPT)</h2>

<p>In GPT, each token can attend only to itself and to earlier tokens, never to future ones.
For token 1 (“think”), that means:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Token 0 ("I") ✓
Token 1 ("think" - itself) ✓
Token 2 ("therefore") ✗ (future - mask it!)
Token 3 ("I") ✗ (future - mask it!)
Token 4 ("am") ✗ (future - mask it!)
</code></pre></div></div>

<p>We set future positions to <code class="language-plaintext highlighter-rouge">-∞</code> so they become 0 after softmax:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>After Causal Mask:
[
  [0.965,  -∞,    -∞,     -∞,     -∞    ],  ← Token 0 (only sees itself)
  [0.425,  2.11,  -∞,     -∞,     -∞    ],  ← Token 1 (sees 0,1)
  [-0.92,  0.105, 1.91,   -∞,     -∞    ],  ← Token 2 (sees 0,1,2)
  [0.965,  0.425, -0.92,  0.965,  -∞    ],  ← Token 3 (sees 0,1,2,3)
  [-0.2015, -0.0935, 0.939, -0.2015, 1.9981]  ← Token 4 (sees all)
]
</code></pre></div></div>
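<p>The mask is just the upper triangle of a \(T \times T\) matrix, excluding the diagonal. This NumPy sketch builds it the same way the model code does with <code class="language-plaintext highlighter-rouge">torch.triu(torch.ones(block_size, block_size), 1)</code> and applies it to the toy scaled scores (identity weights assumed, as above).</p>

```python
import numpy as np

X = np.array([
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-0.3, 1.9, -0.6, 0.4],      # "think"
    [-1.7, 0.2, 0.8, -0.5],      # "therefore"
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-1.34, -0.45, 0.45, 1.34],  # "am"
])
scaled = (X @ X.T) / np.sqrt(X.shape[1])  # Steps 2 and 3 with identity weights

T = scaled.shape[0]
# True strictly above the diagonal, i.e. at the "future" positions
mask = np.triu(np.ones((T, T)), k=1) == 1
masked = np.where(mask, -np.inf, scaled)  # like att.masked_fill(mask, float('-inf'))
print(mask.astype(int))
```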
<hr />

<h2 id="step-5--apply-softmax-get-attention-weights">Step 5 — Apply Softmax (Get Attention Weights)</h2>

<p>Softmax converts each row of scaled scores into weights that are non-negative and sum to 1:</p>

\[A[i,j] = \frac{\exp(S_{scaled}[i,j])}{\sum_{k} \exp(S_{scaled}[i,k])}\]

<p>Token 0 (“I”) as input: <code class="language-plaintext highlighter-rouge">[0.965, -∞, -∞, -∞, -∞]</code></p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>e^{0.965} ≈ 2.625
e^{-∞} = 0
</code></pre></div></div>

<p>Sum: 2.625</p>

<p>Weights: [2.625/2.625, 0, 0, 0, 0] = [1.0, 0, 0, 0, 0]</p>

<p>Token 0 pays 100% attention to itself (it cannot see any future tokens).</p>

<p>Token 1 (“think”) as input: <code class="language-plaintext highlighter-rouge">[0.425, 2.11, -∞, -∞, -∞]</code></p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>e^0.425 ≈ 1.529
e^2.11 ≈ 8.247
</code></pre></div></div>

<p>Sum: 1.529 + 8.247 = 9.776</p>

<p>Weights:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[1.529/9.776, 8.247/9.776, 0, 0, 0] ≈ [0.156, 0.844, 0, 0, 0]
</code></pre></div></div>

<p>Token 1 pays 15.6% attention to “I” (token 0) and 84.4% attention to itself.</p>

<p>The remaining rows follow the same procedure.</p>

<p>This is the final Attention Weight Matrix:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Attention Weights A = softmax(Masked Scores) = [
  [1.0,   0,     0,     0,     0    ],  ← Token 0
  [0.156, 0.844, 0,     0,     0    ],  ← Token 1
  [0.048, 0.134, 0.817, 0,     0    ],  ← Token 2
  [0.366, 0.213, 0.056, 0.366, 0    ],  ← Token 3
  [0.066, 0.073, 0.205, 0.066, 0.591]   ← Token 4
]
</code></pre></div></div>
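<p>Here is a minimal NumPy sketch of the row-wise softmax that reproduces the matrix above (identity weights assumed). Subtracting each row's maximum before exponentiating is a standard trick to avoid overflow; it does not change the result, and masked entries still become \(e^{-\infty} = 0\).</p>

```python
import numpy as np

def softmax(z):
    # Row-wise softmax; subtracting the row max avoids overflow, and exp(-inf) = 0
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

X = np.array([
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-0.3, 1.9, -0.6, 0.4],      # "think"
    [-1.7, 0.2, 0.8, -0.5],      # "therefore"
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-1.34, -0.45, 0.45, 1.34],  # "am"
])
T, d = X.shape
scores = (X @ X.T) / np.sqrt(d)                             # Steps 2 and 3
scores[np.triu(np.ones((T, T), dtype=bool), k=1)] = -np.inf  # Step 4: causal mask
A = softmax(scores)                                          # Step 5
print(np.round(A, 3))
```

<p>The diagonal is never masked, so every row has at least one finite score and the row maximum is always finite.</p>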

<hr />

<h2 id="step-6--weighted-sum-of-values-information-flow">Step 6 — Weighted Sum of Values (Information Flow)</h2>

<p>Finally, each token’s output is the attention-weighted sum of the value vectors:</p>

\[output[i] = \sum_j A[i,j] \, V[j]\]

<p>Because the weights in each row sum to 1, every output is a weighted average of the values of the tokens that token can see. For example, the output for “think” is \(0.156 \, V[0] + 0.844 \, V[1]\). This is how each token picks up context from earlier tokens.</p>
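<p>As a recap, here is a compact NumPy sketch that chains Steps 1 to 6 into single-head causal attention on the toy sentence. Identity weights are assumed as before, and the function name <code>causal_attention</code> is illustrative; it is a teaching version of what one head computes, without batching or dropout.</p>

```python
import numpy as np

def causal_attention(X, W_Q, W_K, W_V):
    """Single-head causal self-attention, following Steps 1 to 6."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V                    # Step 1: projections
    S = (Q @ K.T) / np.sqrt(K.shape[1])                     # Steps 2-3: scaled scores
    T = S.shape[0]
    S = np.where(np.triu(np.ones((T, T)), k=1) == 1, -np.inf, S)  # Step 4: causal mask
    A = np.exp(S - S.max(axis=1, keepdims=True))            # Step 5: stable softmax
    A = A / A.sum(axis=1, keepdims=True)
    return A @ V, A                                         # Step 6: weighted sum of values

X = np.array([
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-0.3, 1.9, -0.6, 0.4],      # "think"
    [-1.7, 0.2, 0.8, -0.5],      # "therefore"
    [1.0, 0.5, 0.2, 0.8],        # "I"
    [-1.34, -0.45, 0.45, 1.34],  # "am"
])
W = np.eye(4)
out, A = causal_attention(X, W, W, W)
print(np.round(out, 3))
```

<p>Row 0 is exactly V[0], since “I” can only see itself, and the row for “think” comes out at roughly [-0.097, 1.681, -0.475, 0.463], a blend dominated by V[1].</p>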

<hr />

<h1 id="43-multi-head-attention">4.3 Multi-Head Attention</h1>
<p>So far, we have acted as if there were just one “attention head” computing \(Q\), \(K\) and \(V\), the scores, the weights and finally the weighted sum of values. In practice, GPT-2 (and our tiny version) uses multi-head attention.</p>

<p>The idea is simple:</p>
<ul>
  <li>A single head computes one set of attention weights per token, so it tends to capture one kind of relationship at a time.</li>
  <li>Multiple heads can look at the same sequence from different “perspectives” or “subspaces”.</li>
</ul>

<p>We start with embeddings of dimension \(d = 256\) and split them into \(H = 8\) heads. Then each head operates in a smaller subspace of size:</p>

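\[\mathrm{head\_dim} = d / H = 256 / 8 = 32\]

<p>So, instead of a single set of queries, keys and values with shape (where \(B\) is the batch size and \(T\) the sequence length)</p>

\[Q, K, V \in \mathbb{R}^{B \times T \times d}\]

<p>we have one smaller set per head:</p>

\[Q, K, V \in \mathbb{R}^{B \times H \times T \times \mathrm{head\_dim}}\]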

<p>Each head runs its own scaled dot-product attention, then we concatenate the results from all heads back into a 256-dimensional vector per token.</p>

<p>In code, we do this by projecting once into a single \(3d\)-dimensional space, splitting the result into Q, K and V, and reshaping each into heads:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">CausalSelfAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">n_head</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">n_embd</span> <span class="o">%</span> <span class="n">n_head</span> <span class="o">==</span> <span class="mi">0</span>
        <span class="n">self</span><span class="p">.</span><span class="n">n_head</span> <span class="o">=</span> <span class="n">n_head</span>
        <span class="n">self</span><span class="p">.</span><span class="n">head_dim</span> <span class="o">=</span> <span class="n">n_embd</span> <span class="o">//</span> <span class="n">n_head</span>

        <span class="c1"># One projection to produce Q, K, V stacked along the last dim
</span>        <span class="n">self</span><span class="p">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">out_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">)</span>

        <span class="c1"># Causal mask (True above the diagonal)
</span>        <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">triu</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">block_size</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="nf">register_buffer</span><span class="p">(</span><span class="sh">"</span><span class="s">mask</span><span class="sh">"</span><span class="p">,</span> <span class="n">mask</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>

        <span class="n">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>  <span class="c1"># (batch, time, channels)
</span>
        <span class="c1"># 1) Project once, then split into Q, K, V
</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>          <span class="c1"># (B, T, 3*C)
</span>        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">.</span><span class="nf">split</span><span class="p">(</span><span class="n">C</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>  <span class="c1"># each (B, T, C)
</span>
        <span class="c1"># 2) Reshape into heads: (B, n_head, T, head_dim)
</span>        <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">n_head</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">head_dim</span><span class="p">).</span><span class="nf">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">n_head</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">head_dim</span><span class="p">).</span><span class="nf">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">n_head</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">head_dim</span><span class="p">).</span><span class="nf">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>

        <span class="c1"># 3) Scaled dot-product attention per head
</span>        <span class="n">att</span> <span class="o">=</span> <span class="p">(</span><span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="p">.</span><span class="nf">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">head_dim</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span>
        <span class="n">att</span> <span class="o">=</span> <span class="n">att</span><span class="p">.</span><span class="nf">masked_fill</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">mask</span><span class="p">[:</span><span class="n">T</span><span class="p">,</span> <span class="p">:</span><span class="n">T</span><span class="p">],</span> <span class="nf">float</span><span class="p">(</span><span class="sh">'</span><span class="s">-inf</span><span class="sh">'</span><span class="p">))</span>
        <span class="n">att</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="nf">softmax</span><span class="p">(</span><span class="n">att</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">att</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">dropout</span><span class="p">(</span><span class="n">att</span><span class="p">)</span>

        <span class="c1"># 4) Weighted sum of values
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">att</span> <span class="o">@</span> <span class="n">v</span>  <span class="c1"># (B, n_head, T, head_dim)
</span>
        <span class="c1"># 5) Concatenate heads back into (B, T, C)
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="nf">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nf">contiguous</span><span class="p">().</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>

        <span class="c1"># 6) Final linear projection (mix head outputs)
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">out_proj</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">out</span>
</code></pre></div></div>
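<p>The reshapes are the trickiest part of this class, so here is a NumPy mirror of them on random data. It assumes \(d = 256\) and \(H = 8\) from the text, with an arbitrary batch size and sequence length. It checks that each head simply owns a contiguous 32-channel slice of every token vector, that each head gets its own \(T \times T\) score matrix, and that merging the heads undoes the split exactly.</p>

```python
import numpy as np

B, T, C, H = 2, 5, 256, 8  # C = d = 256 and H = 8 from the text; B and T are arbitrary
hd = C // H                # head_dim = 32

rng = np.random.default_rng(0)
q = rng.standard_normal((B, T, C))
k = rng.standard_normal((B, T, C))

def split_heads(t):
    # (B, T, C) -> (B, T, H, hd) -> (B, H, T, hd); same as .view(B, T, H, hd).transpose(1, 2)
    return t.reshape(B, T, H, hd).transpose(0, 2, 1, 3)

qh, kh = split_heads(q), split_heads(k)
att = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(hd)  # one (T, T) score matrix per head
print(qh.shape, att.shape)                         # (2, 8, 5, 32) (2, 8, 5, 5)

# Head 3 of token 1 is just channels 3*hd to 4*hd - 1 of that token's vector
assert np.array_equal(qh[0, 3, 1], q[0, 1, 3 * hd:4 * hd])

# Merging reverses the split: .transpose(1, 2).contiguous().view(B, T, C)
merged = qh.transpose(0, 2, 1, 3).reshape(B, T, C)
print(np.array_equal(merged, q))                   # True
```

<p>The <code class="language-plaintext highlighter-rouge">.contiguous()</code> call in the PyTorch version is needed because <code class="language-plaintext highlighter-rouge">.view</code> requires contiguous memory after a transpose; NumPy's <code>reshape</code> copies automatically when it has to.</p>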

<p>Intuitively, one head might learn “which subject pronoun does this verb refer to?” while another learns “which earlier tokens carry tense and aspect information?”. In our running sentence <code class="language-plaintext highlighter-rouge">I think therefore I am</code>, you can imagine one head focusing on how “therefore” connects the two clauses and another focusing on the grammatical structure around “I” and “am”.</p>

<p>All of these perspectives are combined back together, giving each token a rich, context-aware representation.</p>

<hr />

<h1 id="44-feedforward-mlp">4.4 Feedforward (MLP)</h1>

<p>Self-attention mixes information across positions (tokens look at other tokens). After that mixing step, each token also needs to transform its own representation non-linearly, independently of the other positions. This is the job of the feedforward network (FFN), sometimes called the MLP block.</p>

<p>In each Transformer block we apply:</p>

\[\mathrm{FFN}(x) = \mathrm{GELU}(x W_1 + b_1) \, W_2 + b_2\]

<p>where</p>
<ul>
  <li>\(W_1 \in \mathbb{R}^{d \times 4d}\) (expands from \(d\) to \(4d\))</li>
  <li>\(W_2 \in \mathbb{R}^{4d \times d}\) (projects back to \(d\))</li>
  <li>\(b_1\) and \(b_2\) are bias vectors</li>
  <li>\(\mathrm{GELU}\) is the GELU (Gaussian Error Linear Unit) activation function.</li>
</ul>

<p>In our implementation, \(d = 256\), so the hidden size is \(4d = 1024\).</p>

<h2 id="code">Code</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">),</span>  <span class="c1"># 256 → 1024 (d → 4d)
</span>            <span class="nc">GELU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">),</span>  <span class="c1"># 1024 → 256 (4d → d)
</span>            <span class="n">nn</span><span class="p">.</span><span class="nc">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">),</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">self</span><span class="p">.</span><span class="nf">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<p>Why do we expand to \(4d\) and then shrink back?</p>
<ul>
  <li>The expansion gives the model extra capacity to represent complex combinations of features.</li>
  <li>The nonlinearity (GELU) lets the model build rich, non-linear functions of the attention output.</li>
  <li>Shrinking back to \(d\) keeps the dimensionality consistent for residual connections.</li>
</ul>

<p>If we zoom into a single token’s vector after attention, call it \(z \in \mathbb{R}^{256}\), the FFN does:</p>

<ol>
  <li>Linearly project to \(\mathbb{R}^{1024}\); think of this as generating many candidate features.</li>
  <li>Apply GELU to decide which features are “softly activated”.</li>
  <li>Linearly project back to \(\mathbb{R}^{256}\), compressing all that information down into an updated token representation.</li>
</ol>
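<p>Those three steps for a single token can be sketched in NumPy as follows. The weights are random stand-ins (hypothetical; the model learns them), and the sketch uses the tanh approximation of GELU, the form used in the original GPT-2 code. Which variant the post's <code class="language-plaintext highlighter-rouge">GELU()</code> implements is not shown in this section; in PyTorch, <code class="language-plaintext highlighter-rouge">nn.GELU()</code> defaults to the exact erf-based form and <code class="language-plaintext highlighter-rouge">nn.GELU(approximate='tanh')</code> gives this approximation.</p>

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

d = 256
rng = np.random.default_rng(0)
# Hypothetical random weights; in the model these are learned
W1, b1 = rng.standard_normal((d, 4 * d)) * 0.02, np.zeros(4 * d)
W2, b2 = rng.standard_normal((4 * d, d)) * 0.02, np.zeros(d)

z = rng.standard_normal(d)  # one token's vector after attention
h = gelu(z @ W1 + b1)       # 1) expand to R^1024, 2) softly gate with GELU
out = h @ W2 + b2           # 3) project back to R^256
print(h.shape, out.shape)   # (1024,) (256,)
```

<p>GELU passes large positive inputs almost unchanged and squashes large negative ones to nearly 0, with a smooth transition in between, which is what “softly activated” means here.</p>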

<p>So attention says “what should I look at across the sentence?”, and the FFN says “given what I just looked at, how should I transform myself?”.</p>

<hr />
<h1 id="45-residual-connections-putting-the-block-together">4.5 Residual Connections (Putting the Block Together)</h1>

<p>Now we can see the full TransformerBlock:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TransformerBlock</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">n_head</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">ln1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">LayerNorm</span><span class="p">(</span><span class="n">n_embd</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">attn</span> <span class="o">=</span> <span class="nc">CausalSelfAttention</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">n_head</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">ln2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">LayerNorm</span><span class="p">(</span><span class="n">n_embd</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># 1) Pre-norm + attention + residual
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">attn</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

        <span class="c1"># 2) Pre-norm + feedforward + residual
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">mlp</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<p>Key ideas:</p>

<ol>
  <li>Pre-LayerNorm: We normalize before attention and before the MLP:
    <ul>
      <li>This stabilizes training, especially for deeper stacks of blocks.</li>
      <li>It helps gradients flow more smoothly.</li>
    </ul>
  </li>
  <li>Residual connections: Each major sublayer (attention, then FFN) is added back to its input:
    <ul>
      <li><code class="language-plaintext highlighter-rouge">x = x + sublayer(layer_norm(x))</code></li>
    </ul>

    <p>This means the network can always “fall back” to the identity function if needed. In practice, it:</p>
    <ul>
      <li>Makes optimization easier (better gradient flow).</li>
      <li>Lets deeper models train without collapsing.</li>
    </ul>
  </li>
</ol>
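<p>The pre-norm residual pattern can be sketched in NumPy with stand-in sublayers (hypothetical functions, not the real attention or MLP). The layer norm here omits the learned scale and shift, which <code class="language-plaintext highlighter-rouge">nn.LayerNorm</code> initializes to 1 and 0. The sketch makes the “fall back to identity” point literal: if both sublayers output zeros, the block returns its input unchanged.</p>

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector to mean 0 and variance 1 (no learned scale/shift here)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def transformer_block(x, attn, mlp):
    # Same pattern as TransformerBlock.forward: pre-norm, sublayer, residual add
    x = x + attn(layer_norm(x))
    x = x + mlp(layer_norm(x))
    return x

def zero_sublayer(h):
    return np.zeros_like(h)  # a sublayer that contributes nothing

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 256))  # 5 tokens, d = 256

# If both sublayers output zeros, the block is exactly the identity
print(np.array_equal(transformer_block(x, zero_sublayer, zero_sublayer), x))  # True

W = rng.standard_normal((256, 256)) * 0.01  # hypothetical small linear sublayer

def small_sublayer(h):
    return h @ W

y = transformer_block(x, small_sublayer, small_sublayer)
print(y.shape)  # (5, 256)
```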

<p>If you imagine passing our running sentence <code class="language-plaintext highlighter-rouge">I think therefore I am</code> through a stack of 4 such blocks:</p>
<ul>
  <li>At early layers, the residual means we’re only making small tweaks to the embeddings.</li>
  <li>At later layers, the accumulated updates encode rich patterns: semantic relationships, syntax, phrase structure, etc.</li>
</ul>

<p>The block is the core “processing unit” that we repeat multiple times.</p>

<hr />

<h1 id="5-full-tiny-gpt-2-model">5. Full Tiny GPT-2 Model</h1>

<p>Now, we put everything together in our <code class="language-plaintext highlighter-rouge">GPT2Tiny</code> class. This contains the token embeddings, positional embeddings, 8 transformer blocks, layer norm, and the language modeling head.</p>

<p>And we use it like this:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="nc">GPT2Tiny</span><span class="p">(</span>
    <span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
    <span class="n">n_layer</span><span class="o">=</span><span class="n">n_layer</span><span class="p">,</span>
    <span class="n">n_head</span><span class="o">=</span><span class="n">n_head</span><span class="p">,</span>
    <span class="n">n_embd</span><span class="o">=</span><span class="n">n_embd</span><span class="p">,</span>
    <span class="n">block_size</span><span class="o">=</span><span class="n">block_size</span><span class="p">,</span>
    <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div></div>
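<p>The <code class="language-plaintext highlighter-rouge">GPT2Tiny</code> class body is not shown here, but the blocks above are. As a sanity check on model size, this plain-Python sketch counts the learnable parameters in one <code class="language-plaintext highlighter-rouge">TransformerBlock</code> exactly as written above (qkv without bias, <code class="language-plaintext highlighter-rouge">out_proj</code> and both FFN layers with bias, two LayerNorms), assuming <code class="language-plaintext highlighter-rouge">GELU()</code> has no parameters. The helper name <code>block_params</code> is hypothetical; the causal mask is a buffer, not a parameter.</p>

```python
def block_params(n_embd):
    # Learnable parameters in one TransformerBlock as defined above
    qkv = n_embd * 3 * n_embd                      # nn.Linear(n_embd, 3*n_embd, bias=False)
    out_proj = n_embd * n_embd + n_embd            # nn.Linear(n_embd, n_embd), with bias
    layer_norms = 2 * (2 * n_embd)                 # ln1 and ln2: weight + bias each
    ffn = (n_embd * 4 * n_embd + 4 * n_embd) \
        + (4 * n_embd * n_embd + n_embd)           # expand and project-back layers
    return qkv + out_proj + layer_norms + ffn      # = 12 * n_embd**2 + 10 * n_embd

print(block_params(256))  # 788992
```

<p>Multiply by <code class="language-plaintext highlighter-rouge">n_layer</code> for the whole stack. The token and positional embeddings and the language modeling head add terms that depend on <code class="language-plaintext highlighter-rouge">vocab_size</code> and <code class="language-plaintext highlighter-rouge">block_size</code>, which this sketch does not count.</p>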
<hr />

<h1 id="6-training">6. Training</h1>

<p>LLMs like GPT-2 are autoregressive: they predict the next token from the tokens that came before it. So we train the model to assign high probability to each actual next token, which means minimizing the negative log-likelihood:</p>

\[L = -\sum_{t=1}^{T} \log P(y_t \mid x_1, \ldots, x_t), \qquad y_t = x_{t+1}\]

<p>In practice, for each batch, \(x\) contains token IDs for positions <code class="language-plaintext highlighter-rouge">[1, ... , T]</code> and \(y\) contains token IDs for positions <code class="language-plaintext highlighter-rouge">[2, ... , T+1]</code>. The model outputs logits for each position in \(x\), and we compare the logits at position t with the ground-truth token at position t+1. This is implemented in the <code class="language-plaintext highlighter-rouge">forward</code> method of the <code class="language-plaintext highlighter-rouge">GPT2Tiny</code> class with <code class="language-plaintext highlighter-rouge">F.cross_entropy</code> (which, by default, averages over positions instead of summing; this only rescales the loss).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">train_steps</span><span class="p">):</span>
    <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="nf">get_batch</span><span class="p">(</span><span class="sh">"</span><span class="s">train</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="nf">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="n">optimizer</span><span class="p">.</span><span class="nf">zero_grad</span><span class="p">()</span>
    <span class="n">loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span>
    <span class="n">optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
</code></pre></div></div>
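<p>The target shift and the loss can be sketched in NumPy as follows. The token IDs and vocabulary size are hypothetical, and <code>cross_entropy</code> is a hand-written stand-in for <code class="language-plaintext highlighter-rouge">F.cross_entropy</code> with its default mean reduction. A useful sanity check falls out of it: a model that gives every token the same logit has a loss of exactly \(\ln(\text{vocab\_size})\), so a freshly initialized model should start near that value.</p>

```python
import numpy as np

def cross_entropy(logits, targets):
    # Mean over positions of -log softmax(logits)[target]
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

tokens = np.array([5, 9, 2, 7, 3, 8])  # hypothetical token IDs, T + 1 = 6
x, y = tokens[:-1], tokens[1:]         # inputs: positions 1..T, targets: positions 2..T+1
print(x, y)                            # [5 9 2 7 3] [9 2 7 3 8]

V = 10                                 # hypothetical vocabulary size
uniform_logits = np.zeros((len(x), V))  # a model with no preference at all
print(cross_entropy(uniform_logits, y), np.log(V))  # both about 2.3026

confident = np.full((len(x), V), -10.0)  # a model that is sure of every correct target
confident[np.arange(len(y)), y] = 10.0
print(cross_entropy(confident, y))       # close to 0
```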

<p>Please note: the entire training process takes about 2 hours and 30 minutes on my T4 GPU.</p>

<h1 id="61-training-loop">6.1 Training Loop</h1>

<p>Around the core loop above, training relies on a few standard ingredients:</p>
<ol>
  <li>AdamW: Adam with weight decay, standard for Transformer-style models.</li>
  <li>Learning rate schedule: a simple warmup (and possibly a decay) helps avoid blowing up early in training.</li>
  <li>Gradient clipping: keeps gradients from exploding, which is especially useful for deeper models or noisy batches.</li>
  <li>Train/val losses: tracking both lets you see if the model is overfitting or under-training.</li>
</ol>
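<p>The post's exact schedule and clipping threshold are not shown here, so the sketch below uses hypothetical values. It shows one common choice, linear warmup followed by cosine decay, plus clipping by global norm, written in plain Python and NumPy so the arithmetic is visible. In PyTorch, clipping is usually a call to <code class="language-plaintext highlighter-rouge">torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)</code> between <code class="language-plaintext highlighter-rouge">loss.backward()</code> and <code class="language-plaintext highlighter-rouge">optimizer.step()</code>, and the learning rate can be applied by setting <code class="language-plaintext highlighter-rouge">group["lr"]</code> for each entry in <code class="language-plaintext highlighter-rouge">optimizer.param_groups</code>.</p>

```python
import math
import numpy as np

# Hypothetical hyperparameters, not the post's actual values
MAX_LR, MIN_LR, WARMUP, TOTAL = 3e-4, 3e-5, 200, 5000

def lr_at(step):
    if step < WARMUP:  # linear warmup from near 0 up to MAX_LR
        return MAX_LR * (step + 1) / WARMUP
    progress = min(1.0, (step - WARMUP) / (TOTAL - WARMUP))
    # cosine decay from MAX_LR down to MIN_LR
    return MIN_LR + 0.5 * (MAX_LR - MIN_LR) * (1 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm=1.0):
    # Rescale all gradients together so their combined L2 norm is at most max_norm
    total = math.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total

print([round(lr_at(s), 6) for s in (0, 199, 200, 2600, 5000)])
clipped, norm = clip_by_global_norm([np.array([3.0, 4.0])])
print(norm, clipped[0])  # original norm 5.0; clipped vector has norm ~1
```

<p>Clipping by the global norm, rather than per tensor, preserves the direction of the overall update and only shrinks its length.</p>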

<p>Given the small dataset and tiny architecture, you should see the training loss fall steadily and the validation loss track it reasonably closely.</p>

<hr />

<h1 id="7-text-generation">7. Text Generation</h1>

<p>Once the model has learned to predict the next token, we can use it to generate text. We start with a prompt and ask the model for the probability distribution over the next token. We then sample a token from this distribution, append it to the context, and repeat, so the text grows one token at a time (autoregressively).</p>
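<p>That sampling loop can be sketched in NumPy as below. The model is replaced by a hypothetical <code>next_token_logits</code> callable, the temperature parameter is an optional extra, and cropping the context to <code class="language-plaintext highlighter-rouge">block_size</code> reflects the fact that the causal mask and positional embeddings only cover that many positions. The toy model is made up purely so the loop can run.</p>

```python
import numpy as np

def generate(next_token_logits, ids, max_new_tokens, block_size, temperature=1.0, rng=None):
    # next_token_logits: stand-in for the model; maps a list of token IDs
    # to a 1-D array of logits over the vocabulary for the next token
    if rng is None:
        rng = np.random.default_rng()
    ids = list(ids)
    for _ in range(max_new_tokens):
        context = ids[-block_size:]                     # keep at most block_size tokens
        logits = np.asarray(next_token_logits(context), dtype=float) / temperature
        probs = np.exp(logits - logits.max())           # softmax over the vocabulary
        probs /= probs.sum()
        next_id = int(rng.choice(len(probs), p=probs))  # sample the next token
        ids.append(next_id)                             # feed it back in
    return ids

def toy_model(context, vocab_size=5):
    # Hypothetical toy "model": strongly prefers (last token + 1) mod vocab_size
    logits = np.zeros(vocab_size)
    logits[(context[-1] + 1) % vocab_size] = 50.0
    return logits

out = generate(toy_model, [0], max_new_tokens=6, block_size=4, rng=np.random.default_rng(0))
print(out)
```

<p>Lower temperatures sharpen the distribution toward the most likely token; higher ones flatten it and make the output more varied.</p>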

<p>And we have finally built a tiny GPT-2 model that generates haikus! I prompted it with <code class="language-plaintext highlighter-rouge">round and round we go</code> and it gives the following output:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code> round and round we go / of all it takes is time $
we are one family / knight errants of the divine / forever ever more $
stop falling in love / it is not good for the soul / loneliness is god $
</code></pre></div></div>

<h1 id="references">References:</h1>
<ol>
  <li><a href="https://github.com/t0n4r/gpt2-implementation">GitHub Repository</a></li>
  <li><a href="https://colab.research.google.com/drive/1zmiC_-CCRgKn0fWWwn5x-_wJseZ4APCt?usp=sharing">Colab Notebook</a></li>
</ol>]]></content><author><name></name></author><category term="literature-review" /><category term="projects" /><category term="tutorial" /><summary type="html"><![CDATA[A Beginner-Friendly Fully Explained Guide]]></summary></entry><entry><title type="html">AlexNet Paper Implementation</title><link href="http://localhost:4000/blog/2025/alexnet-implementation/" rel="alternate" type="text/html" title="AlexNet Paper Implementation" /><published>2025-05-28T00:00:00+06:00</published><updated>2025-05-28T00:00:00+06:00</updated><id>http://localhost:4000/blog/2025/alexnet-implementation</id><content type="html" xml:base="http://localhost:4000/blog/2025/alexnet-implementation/"><![CDATA[<p>AlexNet needs no introduction. This blog is more of an exercise for myself where I read a paper, replicate it and document the process. The “ImageNet Classification with Deep Convolutional Neural Networks” paper is a good start for this.</p>

<h2 id="introduction">Introduction</h2>

<p>AlexNet, developed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton, marked a significant milestone in deep learning. Before AlexNet won the 2012 ImageNet competition (ILSVRC-2012), convolutional neural networks (CNNs) were rarely used for large-scale image recognition, largely because of computational constraints and the shallow architectures that were practical to train.</p>

<h2 id="dataset-used">Dataset Used</h2>

<p>The dataset used in the paper was ILSVRC, a subset of the full ImageNet dataset consisting of approximately 1.2 million training images, 50,000 validation images, and 150,000 test images, all categorized into 1,000 classes. The images came in varying resolutions, but the model required a constant input size, so the authors down-sampled every image to 256x256 pixels: they rescaled each image so that its shorter side was 256 pixels and then cropped out the central 256x256 patch. Note that 256x256 is not the network’s input size; the network is trained on 224x224 patches cropped from these images, as described in the next section.</p>
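<p>The paper’s resize-then-crop step is simple arithmetic on image sizes. The sketch below computes it for a single image; the function name <code>resize_then_center_crop</code> is hypothetical, and the actual pixel resampling is left to an image library (for example, Pillow’s <code>Image.resize</code> and <code>Image.crop</code>, which accept a size tuple and a (left, upper, right, lower) box respectively).</p>

```python
def resize_then_center_crop(width, height, target=256):
    """Return the resized (width, height) and the (left, top, right, bottom) box of the central crop."""
    scale = target / min(width, height)          # the shorter side becomes `target`
    new_w, new_h = round(width * scale), round(height * scale)
    left = (new_w - target) // 2                 # centre the crop horizontally
    top = (new_h - target) // 2                  # ... and vertically
    return (new_w, new_h), (left, top, left + target, top + target)

print(resize_then_center_crop(500, 375))  # → ((341, 256), (42, 0, 298, 256))
```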

<p>The dataset I used for the implementation was CIFAR-10, which is much smaller in comparison. It contains only 60,000 color images, 6,000 per class across 10 classes, split into 50,000 training images and 10,000 test images. I held out 10,000 of the training images as a validation set. CIFAR-10 images are low resolution, only 32x32 pixels.</p>
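<p>Holding out 10,000 of the 50,000 training images amounts to shuffling the indices once and slicing them. The sketch below is a hypothetical, framework-free version (<code>split_indices</code> and the seed are illustrative); in PyTorch the same thing is usually done with <code>torch.utils.data.random_split(dataset, [40000, 10000])</code>, though I am not asserting that this is exactly how the notebook does it.</p>

```python
import random

def split_indices(n_total=50_000, n_val=10_000, seed=0):
    """Shuffle the training indices once and hold out the first n_val as validation."""
    idx = list(range(n_total))
    random.Random(seed).shuffle(idx)       # fixed seed -> reproducible split
    return idx[n_val:], idx[:n_val]        # (train indices, validation indices)

train_idx, val_idx = split_indices()
print(len(train_idx), len(val_idx))  # → 40000 10000
```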

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        



<figure>
  <picture>
    <!-- Auto scaling with imagemagick -->
    <!--
      See https://www.debugbear.com/blog/responsive-images#w-descriptors-and-the-sizes-attribute and
      https://developer.mozilla.org/en-US/docs/Learn/HTML/Multimedia_and_embedding/Responsive_images for info on defining 'sizes' for responsive images
    -->
    
      
        <source class="responsive-img-srcset" srcset="/assets/img/AlexNet_Implementation/CIFAR_original_image-480.webp 480w,/assets/img/AlexNet_Implementation/CIFAR_original_image-800.webp 800w,/assets/img/AlexNet_Implementation/CIFAR_original_image-1400.webp 1400w," type="image/webp" sizes="95vw" />
      
    
    <img src="/assets/img/AlexNet_Implementation/CIFAR_original_image.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" />
  </picture>

  
</figure>

    </div>
</div>
<div class="caption">
    Figure 1: Sample Image of CIFAR-10 Dataset
</div>

<h2 id="data-augmentation">Data Augmentation</h2>

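<p>The paper used two forms of data augmentation, both cheap enough to generate on the fly (on the CPU, while the GPU trains on the previous batch), so the transformed images never need to be stored on disk. The first extracts random 224x224 patches, together with their horizontal reflections, from the 256x256 images and trains the network on these patches rather than on the full images. According to the paper, this increases the size of the training set by a factor of 2048, although the resulting examples are highly interdependent.</p>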

<p>The images I worked with, however, are only 32x32, so cropping them down even further would discard too much of each image. Instead, I padded each image with 4 pixels on every side (making it 40x40) and then randomly cropped it back to 32x32 pixels, which shifts the content by up to 4 pixels in each direction. Finally, I applied horizontal reflections to these cropped images.</p>
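<p>Both cropping schemes can be sketched with NumPy on H x W x C arrays. This is an illustrative version, not the repository’s code: the function names are hypothetical, the flip probability of 0.5 and the zero-valued padding are assumptions, and square images are assumed (true for both the 256x256 ImageNet images and the 32x32 CIFAR-10 images).</p>

```python
import numpy as np

def random_crop_flip(img, crop, rng):
    """Take a random crop x crop patch from an H x W x C image, then mirror it with probability 0.5."""
    h, w = img.shape[:2]
    top = rng.integers(0, h - crop + 1)      # upper bound is exclusive
    left = rng.integers(0, w - crop + 1)
    patch = img[top:top + crop, left:left + crop]
    if rng.random() < 0.5:
        patch = patch[:, ::-1]               # horizontal reflection: reverse the width axis
    return patch

def pad_crop_flip(img, pad, rng):
    """CIFAR-10 variant: pad each side by `pad` pixels, then crop back to the original (square) size."""
    size = img.shape[0]
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="constant")  # zero fill (assumed)
    return random_crop_flip(padded, size, rng)

rng = np.random.default_rng(0)
imagenet_like = np.zeros((256, 256, 3), dtype=np.uint8)
cifar_like = np.ones((32, 32, 3), dtype=np.uint8)
print(random_crop_flip(imagenet_like, 224, rng).shape)  # → (224, 224, 3)
print(pad_crop_flip(cifar_like, 4, rng).shape)          # → (32, 32, 3)
```

<p>Padding first is what makes the CIFAR-10 variant work: the crop window can slide by up to 4 pixels without shrinking the image.</p>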

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        



<figure>
  <picture>
    <!-- Auto scaling with imagemagick -->
    <!--
      See https://www.debugbear.com/blog/responsive-images#w-descriptors-and-the-sizes-attribute and
      https://developer.mozilla.org/en-US/docs/Learn/HTML/Multimedia_and_embedding/Responsive_images for info on defining 'sizes' for responsive images
    -->
    
      
        <source class="responsive-img-srcset" srcset="/assets/img/AlexNet_Implementation/CIFAR_padded_cropped-480.webp 480w,/assets/img/AlexNet_Implementation/CIFAR_padded_cropped-800.webp 800w,/assets/img/AlexNet_Implementation/CIFAR_padded_cropped-1400.webp 1400w," type="image/webp" sizes="95vw" />
      
    
    <img src="/assets/img/AlexNet_Implementation/CIFAR_padded_cropped.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" />
  </picture>

  
</figure>

    </div>
    <div class="col-sm mt-3 mt-md-0">
        



<figure>
  <picture>
    <!-- Auto scaling with imagemagick -->
    <!--
      See https://www.debugbear.com/blog/responsive-images#w-descriptors-and-the-sizes-attribute and
      https://developer.mozilla.org/en-US/docs/Learn/HTML/Multimedia_and_embedding/Responsive_images for info on defining 'sizes' for responsive images
    -->
    
      
        <source class="responsive-img-srcset" srcset="/assets/img/AlexNet_Implementation/CIFAR_flipped-480.webp 480w,/assets/img/AlexNet_Implementation/CIFAR_flipped-800.webp 800w,/assets/img/AlexNet_Implementation/CIFAR_flipped-1400.webp 1400w," type="image/webp" sizes="95vw" />
      
    
    <img src="/assets/img/AlexNet_Implementation/CIFAR_flipped.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" />
  </picture>

  
</figure>

    </div>
</div>
<div class="caption">
    Figure 2: Image After Padding and Random Cropping and After Horizontal Reflection
</div>

<p>The second augmentation method is a PCA-based colour jitter that mimics variations in the intensity and colour of the illumination. The authors first computed the eigenvectors (principal components) and eigenvalues of the 3 × 3 covariance matrix of RGB pixel values over the zero-centred, unit-scaled ImageNet training set. To every training image they then added the vector</p>

\[\Delta \;=\; \sum_{i=1}^{3} v_i \bigl(\lambda_i \alpha_i\bigr),\]

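<p>where \(v_i\) is the \(i\)-th eigenvector, \(\lambda_i\) its eigenvalue, and \(\alpha_i \sim \mathcal{N}(0, 0.1^2)\) is a Gaussian random variable with mean 0 and standard deviation 0.1, drawn once per image each time that image is used for training. The same \(\Delta\) is added to every pixel, so the colour shifts occur along directions of naturally high variance while the spatial structure stays untouched.</p>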

<p>In my CIFAR-10 implementation I keep the one-shift-per-image idea but make one change. Because I compute the PCA on raw 0–255 pixel values rather than on normalised data, the eigenvalues are several orders of magnitude larger than in the paper’s setting. Multiplying by \(\lambda_i\) therefore pushed many channels outside the valid 0–255 range, and clipping them produced almost pure black or white images. To avoid this, I scale each axis by \(\sqrt{\lambda_i}\), its standard deviation, rather than by the eigenvalue itself, and still multiply by Gaussian noise with \(\sigma = 0.1\). Empirically, this keeps the jitter amplitude comparable to the original augmentation and prevents saturation, while still biasing the perturbation toward directions of greater colour variance.</p>
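<p>A NumPy sketch of the jitter, under stated assumptions: the function names are hypothetical, the random array below is only a stand-in for raw CIFAR-10 pixels, and <code>use_sqrt=False</code> reproduces the paper’s eigenvalue scaling (which, as explained above, is meant for normalised data), while <code>use_sqrt=True</code> is the \(\sqrt{\lambda_i}\) variant described for raw 0–255 pixels.</p>

```python
import numpy as np

def fit_rgb_pca(images):
    """Eigen-decompose the 3x3 covariance of all RGB pixel values in an N x H x W x 3 array."""
    pixels = images.reshape(-1, 3).astype(np.float64)
    cov = np.cov(pixels, rowvar=False)               # 3 x 3 RGB covariance
    eigvals, eigvecs = np.linalg.eigh(cov)           # eigvecs[:, i] is the i-th principal component
    return np.clip(eigvals, 0.0, None), eigvecs      # guard against tiny negative round-off

def pca_jitter(img, eigvals, eigvecs, rng, sigma=0.1, use_sqrt=True):
    """Add one colour shift, delta = sum_i v_i * (s_i * alpha_i), to every pixel of an H x W x 3 image."""
    alpha = rng.normal(0.0, sigma, size=3)           # drawn once per image
    s = np.sqrt(eigvals) if use_sqrt else eigvals    # sqrt: CIFAR-10 variant; eigvals: the paper's rule
    delta = eigvecs @ (s * alpha)                    # shape (3,), broadcast over all pixels
    return np.clip(img.astype(np.float64) + delta, 0.0, 255.0)

rng = np.random.default_rng(0)
train = rng.integers(0, 256, size=(100, 32, 32, 3))  # stand-in for raw CIFAR-10 pixels
eigvals, eigvecs = fit_rgb_pca(train)
jittered = pca_jitter(train[0], eigvals, eigvecs, rng)
```

<p>The PCA is fitted once on the training set; only the three \(\alpha_i\) values are redrawn per image, so the augmentation costs almost nothing at training time.</p>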

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        



<figure>
  <picture>
    <!-- Auto scaling with imagemagick -->
    <!--
      See https://www.debugbear.com/blog/responsive-images#w-descriptors-and-the-sizes-attribute and
      https://developer.mozilla.org/en-US/docs/Learn/HTML/Multimedia_and_embedding/Responsive_images for info on defining 'sizes' for responsive images
    -->
    
      
        <source class="responsive-img-srcset" srcset="/assets/img/AlexNet_Implementation/CIFAR_PCA_jittered-480.webp 480w,/assets/img/AlexNet_Implementation/CIFAR_PCA_jittered-800.webp 800w,/assets/img/AlexNet_Implementation/CIFAR_PCA_jittered-1400.webp 1400w," type="image/webp" sizes="95vw" />
      
    
    <img src="/assets/img/AlexNet_Implementation/CIFAR_PCA_jittered.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" />
  </picture>

  
</figure>

    </div>
</div>
<div class="caption">
    Figure 3: Image after PCA Color Jitter
</div>

<h2 id="model-architecture">Model Architecture</h2>

<p>The original AlexNet was designed for 224x224 ImageNet images and used large convolutional kernels (11x11 in the first layer) with aggressive downsampling (stride 4). It used overlapping 3x3 max-pooling (stride 2) and split the network across two GPUs, because a single GTX 580’s 3GB of memory was too small for a network with roughly 60 million parameters. The fully connected (FC) layers (4096 → 4096 → 1000 neurons) were massive and tailored to ImageNet’s 1,000-class output. In contrast, my implementation adapts to CIFAR-10’s 32x32 images by using smaller kernels (3x3 in the first layer), a smaller stride (stride 2), and non-overlapping 2x2 pooling to preserve spatial detail. The classifier keeps the two 4096-unit hidden layers but ends in 10 output neurons (4096 → 4096 → 10) to match CIFAR-10’s 10 classes, and dropout is applied to all FC layers (vs. only the first two in the original) to combat overfitting on the smaller dataset.</p>
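<p>The effect of the first-layer changes can be checked with the standard output-size formula for convolution and pooling, \(\lfloor (W + 2P - K)/S \rfloor + 1\). The padding values below are assumptions: padding 2 for the original 11x11 layer (as in torchvision’s AlexNet, which yields the paper’s 55x55 maps from a 224x224 input) and padding 1 for the adapted 3x3 layer.</p>

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1

# Original first stage on a 224x224 image: 11x11/4 conv (padding 2 assumed), then 3x3/2 max-pool
orig = conv_out(conv_out(224, 11, 4, padding=2), 3, 2)
# The same first stage applied to a 32x32 CIFAR-10 image
shrunk = conv_out(conv_out(32, 11, 4, padding=2), 3, 2)
# Adapted first stage: 3x3/2 conv (padding 1 assumed), then non-overlapping 2x2/2 max-pool
adapted = conv_out(conv_out(32, 3, 2, padding=1), 2, 2)
print(orig, shrunk, adapted)  # → 27 3 8
```

<p>With the original first stage, a CIFAR-10 image would already be down to a 3x3 map after one convolution and one pooling layer, leaving almost no spatial structure for the remaining four convolutional layers; the adapted stage keeps an 8x8 map.</p>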

<p>The adjustments address two core challenges: scale and generalization. CIFAR-10’s 32x32 images lack the fine-grained detail of ImageNet images, which makes large kernels and aggressive pooling counterproductive: they would erase critical spatial information after only a few layers. Smaller kernels and gentler downsampling retain discriminative features. Similarly, a classifier with tens of millions of parameters can easily overfit CIFAR-10’s much smaller training set (50k vs. ImageNet’s 1.2M images), so extending dropout to every FC layer provides stronger regularization. These changes preserve AlexNet’s core design while adapting it to small, low-resolution images.</p>
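<p>To put “enormous parameter count” in numbers, the sketch below counts weights and biases in a stack of fully connected layers. The original’s first FC layer takes the flattened 256 × 6 × 6 feature map; for the adapted model, the size of the flattened conv output is not stated here, so only the 4096 → 4096 → 10 part is counted.</p>

```python
def fc_params(in_features, widths):
    """Count weights + biases in a stack of fully connected layers."""
    total = 0
    for out_features in widths:
        total += in_features * out_features + out_features  # weight matrix + bias vector
        in_features = out_features
    return total

original = fc_params(256 * 6 * 6, [4096, 4096, 1000])  # AlexNet: 256 maps of 6x6, flattened
print(f"{original:,}")  # → 58,631,144
head = fc_params(4096, [4096, 10])  # adapted classifier after its first FC layer
print(f"{head:,}")  # → 16,822,282
```

<p>In the original network, the FC layers alone account for about 58.6M of the roughly 60M parameters, which is why they are the natural place to concentrate regularization.</p>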

<h2 id="training-setup">Training Setup</h2>

<p>The paper trained on the 1.2 million ImageNet training images using two NVIDIA GTX 580 GPUs (3GB VRAM each) with cross-GPU parallelization, where specific layers ran on separate GPUs. They used SGD with momentum 0.9 and a starting learning rate of 0.01, which was manually divided by 10 whenever the validation error rate stopped improving (three times in total). Training ran for roughly 90 epochs with mini-batches of 128 images and weight decay (an L2 penalty) of 0.0005, without gradient clipping. Weights were initialized from a zero-mean Gaussian distribution with standard deviation 0.01; biases were set to 1 in conv2, conv4, conv5, and the fully connected hidden layers, and to 0 in the remaining layers. Data augmentation was limited to PCA jitter, random crops, and horizontal flips.</p>
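<p>The paper states its update rule explicitly: \(v_{i+1} = 0.9\,v_i - 0.0005\,\epsilon\,w_i - \epsilon\,\partial L/\partial w\) and \(w_{i+1} = w_i + v_{i+1}\), where \(\epsilon\) is the learning rate. The sketch below compares it with the formulation PyTorch’s SGD uses (weight decay added to the gradient, momentum buffer scaled by the learning rate at the update); both functions are hypothetical scalar versions written for illustration, and the toy loss \(L(w) = w^2/2\) is an assumption.</p>

```python
def paper_sgd_step(w, v, grad, lr, momentum=0.9, weight_decay=5e-4):
    """Update rule from the AlexNet paper: v <- 0.9 v - 0.0005 lr w - lr grad; w <- w + v."""
    v = momentum * v - weight_decay * lr * w - lr * grad
    return w + v, v

def pytorch_style_sgd_step(w, buf, grad, lr, momentum=0.9, weight_decay=5e-4):
    """PyTorch-style SGD (dampening=0, no Nesterov): weight decay folded into the gradient."""
    buf = momentum * buf + (grad + weight_decay * w)
    return w - lr * buf, buf

# Minimise the toy loss L(w) = w**2 / 2 (gradient = w) with both rules.
w1, v = 1.0, 0.0
w2, buf = 1.0, 0.0
for _ in range(5):
    w1, v = paper_sgd_step(w1, v, grad=w1, lr=0.01)
    w2, buf = pytorch_style_sgd_step(w2, buf, grad=w2, lr=0.01)
print(abs(w1 - w2) < 1e-12)  # → True
```

<p>For a constant learning rate the two rules are equivalent (the paper’s \(v\) equals \(-\epsilon\) times PyTorch’s buffer); they differ slightly only at the steps where the learning rate changes.</p>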

<p>Using PyTorch on a single T4 GPU, I trained the model on CIFAR-10 for 49 epochs with SGD (momentum 0.9), weight decay of 5e-4, and a batch size of 128. The main differences from the paper are a fixed learning rate schedule (step decay from 0.01 to 0.001 after epoch 30), gradient clipping (max norm 2.0), and an expanded augmentation pipeline comprising PCA-based jitter, padding, random cropping, and horizontal flips. Weights used PyTorch’s default initialization (a Kaiming-uniform scheme) rather than the paper’s Gaussian initialization, and 20% of the 50,000 training images (10,000) were held out for validation monitoring.</p>
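<p>The schedule and clipping rules are easy to state precisely in plain Python. The helpers below are hypothetical re-implementations for illustration; in PyTorch the usual equivalents are <code>torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)</code> and <code>torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)</code>, which is presumably close to what the notebook uses, though I have not checked its exact code.</p>

```python
import math

def step_lr(epoch, base_lr=0.01, step_size=30, gamma=0.1):
    """Learning rate for a 0-indexed epoch under step decay: multiply by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

def clip_grad_norm(grads, max_norm=2.0, eps=1e-6):
    """Scale all gradients together if their global L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))   # never scales gradients up
    return [g * coef for g in grads], total_norm

print([step_lr(e) for e in (0, 29, 30, 48)])
clipped, norm = clip_grad_norm([3.0, 4.0])
print(norm, clipped)
```

<p>Clipping by the global norm rescales every gradient by the same factor, so it caps the step size without changing the step’s direction.</p>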

<h2 id="result--analysis">Result &amp; Analysis</h2>

<p>The paper achieved a groundbreaking top-5 test error rate of 15.3% in ILSVRC-2012, more than 10 percentage points better than the second-best entry (26.2%). This success came from scaling depth (8 learned layers: 5 convolutional + 3 fully connected), two-GPU training, and techniques such as ReLU activations and dropout. The paper also reports that the PCA-based augmentation alone reduced the top-1 error rate by over 1%, suggesting that augmentations encouraging invariance to lighting help generalization. These gains came at significant computational cost: five to six days of training on two GPUs.</p>
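<p>Since the results are reported as top-1 and top-5 error, here is a small sketch of how that metric is computed: a prediction counts as correct if the true label is among the \(k\) highest-scoring classes. The function name and the toy scores are hypothetical.</p>

```python
def top_k_error(scores, labels, k=5):
    """Fraction of examples whose true label is not among the k highest-scoring classes."""
    misses = 0
    for s, y in zip(scores, labels):
        top_k = sorted(range(len(s)), key=lambda c: s[c], reverse=True)[:k]  # indices of the k best classes
        misses += int(y not in top_k)
    return misses / len(labels)

scores = [[0.1, 0.5, 0.3, 0.1], [0.6, 0.3, 0.05, 0.05]]
labels = [2, 2]
print(top_k_error(scores, labels, k=1), top_k_error(scores, labels, k=2))  # → 1.0 0.5
```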

<p>My adapted AlexNet achieved 80.45% test accuracy after 49 epochs of training, with final training accuracy at 82.88% and validation accuracy at 79.29%. The model showed consistent improvement throughout training, with a significant accuracy jump after the learning rate decay at epoch 30 (from 76.53% to 79.73% train accuracy in one epoch).</p>

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        



<figure>
  <picture>
    <!-- Auto scaling with imagemagick -->
    <!--
      See https://www.debugbear.com/blog/responsive-images#w-descriptors-and-the-sizes-attribute and
      https://developer.mozilla.org/en-US/docs/Learn/HTML/Multimedia_and_embedding/Responsive_images for info on defining 'sizes' for responsive images
    -->
    
      
        <source class="responsive-img-srcset" srcset="/assets/img/AlexNet_Implementation/AlexNet_Train_Validation-480.webp 480w,/assets/img/AlexNet_Implementation/AlexNet_Train_Validation-800.webp 800w,/assets/img/AlexNet_Implementation/AlexNet_Train_Validation-1400.webp 1400w," type="image/webp" sizes="95vw" />
      
    
    <img src="/assets/img/AlexNet_Implementation/AlexNet_Train_Validation.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" />
  </picture>

  
</figure>

    </div>
</div>
<div class="caption">
    Figure 4: Training/Validation Loss and Accuracy
</div>

<h2 id="conclusion">Conclusion</h2>

<p>My tweaked AlexNet reached 80.45% test accuracy on CIFAR-10 after just 49 epochs; training longer with stronger augmentations would likely lift it further. The complete implementation and training code can be found in my <a href="https://github.com/t0n4r/alexnet-implementation">GitHub repository</a> and on <a href="https://colab.research.google.com/drive/1xWEIkER_NJtfe9GYLByQEyxW681Vl1se">Colab</a>.</p>

<h2 id="references">References</h2>

<ol>
  <li>Krizhevsky, A., Sutskever, I., &amp; Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks. In Advances in neural information processing systems (pp. 1097-1105).</li>
  <li>CIFAR-10 dataset: <a href="https://www.cs.toronto.edu/~kriz/cifar.html">https://www.cs.toronto.edu/~kriz/cifar.html</a></li>
</ol>]]></content><author><name></name></author><category term="literature-review" /><category term="projects" /><category term="tutorial" /><summary type="html"><![CDATA[A comprehensive guide to implementing AlexNet from scratch]]></summary></entry><entry><title type="html">Mechanistic Interpretability via Learning Differential Equations</title><link href="http://localhost:4000/blog/2024/mechanistic-interpretability-via-learning-differential-equations/" rel="alternate" type="text/html" title="Mechanistic Interpretability via Learning Differential Equations" /><published>2024-07-16T00:00:00+06:00</published><updated>2024-07-16T00:00:00+06:00</updated><id>http://localhost:4000/blog/2024/mechanistic-interpretability-via-learning-differential-equations</id><content type="html" xml:base="http://localhost:4000/blog/2024/mechanistic-interpretability-via-learning-differential-equations/"><![CDATA[<p>TL;DR: We report our intermediate results from the AI Safety Camp project “Mechanistic Interpretability Via Learning Differential Equations”. Our goal was to explore transformers that work with numerical time-series data, either inferring the governing differential equation or predicting the next number. Because this task is well formalized, it seems to be an easier problem than interpreting a transformer that works with language. During the project, we applied various interpretability methods to this problem and obtained some preliminary results (e.g., we observed a pattern resembling a numerical computation of the derivative of the input data). We plan to continue this work to validate and extend these preliminary results.</p>
<p>Mechanistic interpretability tries to understand the algorithms implemented by a neural network. This requires identifying the features in the transformer’s activation patterns that correspond to particular patterns in the data it learns. This approach is often quite successful.
Perhaps the best-known example is Golden Gate Claude, where the Anthropic team found a representation of the Golden Gate Bridge in Claude’s “mind” and amplified it so that the model became “obsessed” with the Golden Gate Bridge. Despite such impressive results, mechanistic interpretability is far from solved. One may notice that the problem has two levels of difficulty: we need to figure out not only how transformers represent the features of the data they learn, but also what these features are, bearing in mind that human language is complex and not well formalized. In this project, we tackle these two problems separately. Instead of dealing with the complexity of human language by studying LLMs, we study mathematical transformers. We used two transformers that work with time series: the ODEFormer, trained to predict the symbolic form of an ordinary differential equation from data points of its solution, and the Hugging Face Time-Series Transformer, which predicts the next numerical value in a sequence. Since we understand the underlying mathematical problem, we only need to figure out how the solution process is represented in the transformer’s activation patterns, which seems to be a much more tractable problem.</p>
<p>One may wonder how understanding the representation of ordinary differential equations can help advance mechanistic interpretability for LLMs. There are three potential benefits. First, studying a toy model can advance understanding of a more complicated model, as they often share the same features. Second, if the natural abstraction argument is valid, we can expect the abstractions that transformers learn from the real world to correspond to particular mathematical patterns.
Finally, a fundamental understanding of the transformer’s inner workings will likely be useful for interpretability in the long term, even if not right away, just as the fundamental understanding of electromagnetism did not immediately lead to practical benefits but eventually did.</p>
<p>During the three months of the project, we set up interpretability tools for both the ODEFormer and the Hugging Face Time-Series Transformer, and obtained a few preliminary results for the ODEFormer (see the next section).</p>
<p>For most of our work, we used the ODEFormer, an 86M-parameter encoder-decoder transformer that uses beam search to predict the symbolic form of a first-order autonomous ordinary differential equation (i.e., an equation of the form \(\frac{d\vec{x}}{dt} = f(\vec{x})\)
@font-face {font-family: MJXc-TeX-vec-R; src: local(‘MathJax_Vector’), local(‘MathJax_Vector-Regular’)}
@font-face {font-family: MJXc-TeX-vec-Rw; src /<em>1</em>/: url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/eot/MathJax_Vector-Regular.eot’); src /<em>2</em>/: url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/woff/MathJax_Vector-Regular.woff’) format(‘woff’), url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/otf/MathJax_Vector-Regular.otf’) format(‘opentype’)}
@font-face {font-family: MJXc-TeX-vec-B; src: local(‘MathJax_Vector Bold’), local(‘MathJax_Vector-Bold’)}
@font-face {font-family: MJXc-TeX-vec-Bx; src: local(‘MathJax_Vector’); font-weight: bold}
@font-face {font-family: MJXc-TeX-vec-Bw; src /<em>1</em>/: url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/eot/MathJax_Vector-Bold.eot’); src /<em>2</em>/: url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/woff/MathJax_Vector-Bold.woff’) format(‘woff’), url(‘https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/fonts/HTML-CSS/TeX/otf/MathJax_Vector-Bold.otf’) format(‘opentype’)}
, where the function on the right-hand side has no explicit time dependence) from the numerical solution of this equation. How does ODEFormer solve the problem of inferring a differential equation? We hypothesize that it does so in a few steps. First, it takes a numerical derivative of the data, \(\frac{dx}{dt}\), presumably in the encoder part of the transformer. Second, during beam search in the decoder, it applies various analytical functions \(f(x)\) to the data and compares them to the derivative. When \(\frac{dx}{dt}\) is close enough to \(f(x)\) (e.g. when \(R^2\) exceeds a threshold), the beam search ends and ODEFormer outputs the solution, terminated by an end-of-sequence token. An alternative hypothesis is that ODEFormer explores the data holistically, observing multiple patterns (such as periodicity, asymptotics, higher-order-derivative patterns, or even patterns that humans do not usually pay attention to) and classifying equations based on these patterns.</p>

<p>To validate the derivative hypothesis, we need to find this algorithm explicitly: the derivative calculation in the encoder, and the comparison against candidate functions via \(R^2\) in the decoder. Our preliminary results indicate derivative calculation in simple cases, but we have not yet observed the \(R^2\) calculation. To study how the derivative is calculated and used, we choose as a model system data produced by the logistic equation \(\frac{dx}{dt} = Ax(B - x)\), whose solution is the well-known sigmoid \(x(t) = \frac{B}{1 + e^{-AB(t - t_0)}}\). This lets us pay special attention to the region of high derivative, its width and its location.</p>

<p>First, we explore the info-weighted attention matrices of the heads in the encoder layers (see the Appendix for details and for the motivation behind info-weighted attention).
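</p>

<p>The hypothesized two-step algorithm can be sketched numerically on the logistic example. This is our own NumPy illustration, not ODEFormer’s internals: the parameters A, B and t0, the time grid, the list of candidate right-hand sides and the single-scale least-squares fit are all assumptions chosen for the example.</p>

```python
import numpy as np

def logistic_solution(t, A, B, t0):
    """Closed-form solution x(t) = B / (1 + exp(-A*B*(t - t0))) of dx/dt = A*x*(B - x)."""
    return B / (1.0 + np.exp(-A * B * (t - t0)))

def r_squared(y_true, y_pred):
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1.0 - ss_res / ss_tot

# Illustrative (hypothetical) parameters and a uniform time grid.
A, B, t0 = 0.5, 2.0, 5.0
t = np.linspace(0.0, 10.0, 201)
x = logistic_solution(t, A, B, t0)

# Step 1 of the hypothesis: a numerical derivative of the data.
# np.gradient uses central differences in the interior.
dxdt = np.gradient(x, t)

# Region of high derivative: location and value of the maximum.
i_max = int(np.argmax(dxdt))
print("t at max dx/dt:", t[i_max], "(analytic:", t0, ")")
print("max dx/dt:", dxdt[i_max], "(analytic A*B^2/4:", A * B**2 / 4, ")")

# Step 2: compare the derivative with candidate right-hand sides f(x).
# Each candidate gets a single scale factor c fitted by least squares.
candidates = {
    "c*x": x,
    "c*x^2": x**2,
    "c*x*(B - x)": x * (B - x),
}
scores = {}
for name, basis in candidates.items():
    c = np.dot(basis, dxdt) / np.dot(basis, basis)
    scores[name] = float(r_squared(dxdt, c * basis))

best = max(scores, key=scores.get)
print("R^2 per candidate:", scores)
print("best candidate:", best)
```

<p>Because the numerical derivative matches \(x(B - x)\) up to a scale factor, that candidate reaches \(R^2\) close to 1. Under the hypothesis, this is the kind of signal that would end the beam search.</p>

<p>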
We observe that one head (layer 2, head 7) appears to pick up on differentiation, i.e. on the portion of the curve where the slope is non-zero: its attention is mostly directed to the region of high derivative (see Fig. 1). Second, we used linear probes (see Appendix) on the encoder-layer activations to predict two quantities for sigmoid functions: the time point of the maximum derivative and the value of that maximum. The probes predicted the time point of the maximum derivative successfully, which indicates that this information is linearly represented in the activations (see Fig. 2). Somewhat surprisingly, however, the probes could not predict the value of the maximum derivative well (Fig. 2). This may mean that we did not train on enough data, which we may address in future work. Another possibility is that the absolute value is stored in a complicated way in the transformer (e.g. mantissa and exponent separately), which prevents a linear probe from recovering it. While these preliminary results seem to support our derivative hypothesis, they are not conclusive enough to rule out alternative explanations. We tested only the logistic equation, and so far we could infer only the time point of the maximal derivative and the width of the transition region, not the value of the derivative. Both can be obtained without calculating the derivative explicitly everywhere, so we need more tests to validate this hypothesis.</p>

<p>To investigate the second part of the suggested algorithm (comparing the derivative with various functional forms), we use linear probes to predict the \(R^2\) score of the trajectory under the predicted equation. However, the results offer little evidence that the \(R^2\) score is linearly encoded in the model activations in any layer.</p>

<p>We expected the final decoder layer to have the best performance.
The reason is that this layer is ‘closest’ to the output equation, from whose trajectory we compute the \(R^2\) score we want to predict. However, this turned out not to be the case, as can be seen above (layer index 15). In fact, decoder layer 4 (layer index 7) performs best, which is somewhat surprising. While the results are somewhat negative, part of the issue could be that this experiment only considered 1D ODEs generated with the same method as ODEFormer’s pre-training data. ODEFormer typically handles such equations well, usually achieving \(R^2\) scores close to 1. As a result, the probe has seen almost only \(R^2\) values close to 1, so it cannot predict the correct value when data that is not fit well appears.</p>

<p>In addition to validating the derivative hypothesis, we explored how ODEFormer encodes equation-type classification. Linear probes at the decoder layers perform well at classifying linear versus hyperbolic equations, as well as equations of different dimensions. Using sparse autoencoders, we see that some features are active only for certain equation families, but we have not yet figured out the general rule behind them. For a more specific task, we used the Logit Lens to observe at which layer the sign of the term in a linear differential equation gets predicted, and how this prediction is passed through the layers of the transformer. Finally, we explored how well probes infer the eigenvalues of 2D linear differential equations. This tests the hypothesis that ODEFormer first infers eigenvalues from the data and then constructs the equation from them. If this were the case, we would expect higher accuracy for inferring eigenvalues than for inferring the equation coefficients at some of the lower layers of ODEFormer. So far we do not observe this effect, which leaves the eigenvalue hypothesis unsupported.
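</p>

<p>The sparse autoencoders mentioned above can be illustrated with a minimal NumPy sketch. It uses a ReLU encoder and a linear decoder, trained by full-batch gradient descent on squared reconstruction error plus an L1 penalty on the features. This is an illustrative assumption, not the SAE configuration used in this project. The toy activations (sparse combinations of a few random directions) are hypothetical, as are the helper names <code>make_toy_activations</code> and <code>train_sae</code>, the sizes, the L1 coefficient and the learning rate.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def make_toy_activations(n_samples, d_model=16, n_true=8, k_active=2):
    """Hypothetical stand-in for model activations: each sample is a sparse,
    non-negative combination of k_active out of n_true random unit directions."""
    dirs = rng.normal(size=(n_true, d_model))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
    codes = np.zeros((n_samples, n_true))
    for i in range(n_samples):
        idx = rng.choice(n_true, size=k_active, replace=False)
        codes[i, idx] = rng.uniform(0.5, 1.5, size=k_active)
    return codes @ dirs

def train_sae(X, d_hidden=32, l1_coeff=0.05, lr=0.02, steps=2000):
    """Train f = ReLU(X W_e + b_e), X_hat = f W_d + b_d by full-batch gradient
    descent on mean squared reconstruction error plus l1_coeff * L1 norm of f."""
    n, d_model = X.shape
    W_e = rng.normal(0.0, 0.1, size=(d_model, d_hidden))
    b_e = np.zeros(d_hidden)
    W_d = rng.normal(0.0, 0.1, size=(d_hidden, d_model))
    b_d = np.zeros(d_model)
    losses = []
    for _ in range(steps):
        f = np.maximum(X @ W_e + b_e, 0.0)      # sparse feature activations
        err = f @ W_d + b_d - X                 # reconstruction error
        # f is non-negative, so its L1 norm is simply its sum.
        losses.append(np.sum(err**2) / n + l1_coeff * np.sum(f) / n)
        # Hand-written backpropagation (all gradients use the current parameters).
        d_xhat = 2.0 * err / n
        d_f = d_xhat @ W_d.T + l1_coeff / n
        d_pre = d_f * np.sign(f)                # ReLU gradient: 1 where f is positive, else 0
        W_d -= lr * (f.T @ d_xhat)
        b_d -= lr * d_xhat.sum(axis=0)
        W_e -= lr * (X.T @ d_pre)
        b_e -= lr * d_pre.sum(axis=0)
    return (W_e, b_e, W_d, b_d), losses

X = make_toy_activations(512)
(W_e, b_e, W_d, b_d), losses = train_sae(X)
features = np.maximum(X @ W_e + b_e, 0.0)
print(f"loss: {losses[0]:.3f} at start, {losses[-1]:.3f} at end")
print(f"fraction of active features: {np.count_nonzero(features) / features.size:.2f}")
```

<p>On real data, <code>X</code> would be a batch of activations from one ODEFormer layer, and each column of <code>features</code> would be inspected across equation families.</p>

<p>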
See the Appendix for more details on these results.</p>

<p>As part of this research, we developed and open-sourced a set of tools to support reproducibility and further exploration:</p>

<p>Now that our interpretability tools are working and producing preliminary results, we would like to validate them and explore further. Do we indeed see the computation of the derivative? Can we observe it in the activation patterns? Can we pinpoint the exact mechanism by which it is calculated (e.g. a simple finite-difference scheme, or a higher-order scheme)? Would we see the same pattern, the calculation of the derivative, in the time-series transformer? Moreover, we hope to interpret other potential features found by the SAE that distinguish different types of equations, and to find further features. Once we understand the mechanisms of the transformers above, we would like to explore whether mathematically trained transformers and LLMs (such as Llama) share any activation patterns. It has been demonstrated that LLMs (Llama and GPT-3) can perform time-series prediction without any additional fitting or prompts. Do LLMs use the same algorithm for time-series prediction as a specialized transformer? Finally, understanding which patterns in LLMs are activated during time-series prediction may shed some light on language prediction, if we observe these patterns being activated even in language tasks under certain circumstances.</p>

<p>As the project is still at an intermediate stage, we are quite flexible and look forward to your feedback.</p>

<p>Activation Maximization Team: Utkarsh Priyadarshi (Lead), Varun Piram</p>

<p>In our project, each team leveraged a different method of mechanistic interpretability to understand the inner workings of the ODEFormer and the time-series transformer. Our goal with these methods was to better understand how and why these toy models work by uncovering patterns, structures or circuits within them.
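</p>

<p>One open question above is whether the derivative is computed with a simple finite-difference scheme or a higher-order one. It helps to see how the two kinds of scheme differ. This NumPy sketch compares a first-order forward difference with a second-order central difference on the logistic trajectory. The parameters A, B and t0, the step sizes and the function names are illustrative assumptions.</p>

```python
import numpy as np

def forward_difference(x, h):
    """First-order forward difference (x[i+1] - x[i]) / h; error shrinks like h."""
    return (x[1:] - x[:-1]) / h

def central_difference(x, h):
    """Second-order central difference (x[i+1] - x[i-1]) / (2h); error shrinks like h^2."""
    return (x[2:] - x[:-2]) / (2.0 * h)

def max_errors(h, A=0.5, B=2.0, t0=5.0):
    """Maximum absolute error of both schemes on a logistic trajectory sampled with step h."""
    t = np.arange(0.0, 10.0 + h / 2, h)
    x = B / (1.0 + np.exp(-A * B * (t - t0)))
    exact = A * x * (B - x)                     # true dx/dt of the logistic equation
    err_fwd = np.max(np.abs(forward_difference(x, h) - exact[:-1]))
    err_cen = np.max(np.abs(central_difference(x, h) - exact[1:-1]))
    return err_fwd, err_cen

for h in (0.2, 0.1, 0.05):
    e_f, e_c = max_errors(h)
    print(f"h={h:5.2f}  forward={e_f:.2e}  central={e_c:.2e}")
```

<p>Halving the step size roughly halves the forward-difference error and roughly quarters the central-difference error. This scaling is what distinguishes a first-order scheme from a second-order one.</p>

<p>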
One team was devoted to studying the time-series transformer, and we looked into the following methods specifically for the ODEFormer: Attention Lens, Logit Lens, probing and sparse autoencoders. In earlier efforts, we also looked into SHAP but did not find it fruitful.</p>

<p>We studied HuggingFace’s Time Series Transformer, training it on univariate and multivariate datasets. Rather than inferring a governing law that produces the numerical values at each time point, this model predicts the next value. We focused mainly on applying sparse autoencoders to its activations to understand what the model learned during training. We succeeded in training the univariate and multivariate time-series transformers to predict data generated by simple functions (linear combinations of trigonometric functions, polynomials, exponential functions and hyperbolic functions). We used sparse autoencoders to infer the features corresponding to different types of functions, but have not interpreted the results yet.</p>

<p>For most of our work, we used the ODEFormer, an 86M-parameter encoder-decoder transformer that predicts the symbolic form of a first-order autonomous ordinary differential equation (i.e., an equation of the form \(\frac{d\vec{x}}{dt} = f(\vec{x})\), where the function on the right-hand side has no explicit time dependence) from a numerical solution of this equation. This section gives a brief overview of each method we used to interpret each model.</p>

<p>Probing involves training small auxiliary machine learning models (probes) to predict specific properties from the internal representations of a larger model. The idea is to isolate whether particular features are encoded in the hidden layers of a model. A simple probe (e.g. a logistic regression classifier) is trained on representations from various layers to determine how linearly accessible these properties are.
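</p>

<p>The probes described above can be sketched in a few lines of NumPy. A ridge-regression probe handles continuous targets (such as the time point of the maximum derivative or the \(R^2\) score), and a logistic-regression probe handles binary targets (such as exponential versus hyperbolic). The activations here are hypothetical synthetic stand-ins with a planted linear signal. The helper names, ridge strength, learning rate and step count are likewise arbitrary assumptions. With a real model, each layer’s activations would take the place of <code>H</code>, and the probes would be refit for every layer.</p>

```python
import numpy as np

def add_bias(H):
    """Append a constant column so the probes learn an intercept."""
    return np.hstack([H, np.ones((H.shape[0], 1))])

def fit_linear_probe(H, y, ridge=1e-2):
    """Ridge-regression probe for a continuous target (closed-form solution)."""
    H1 = add_bias(H)
    reg = ridge * np.eye(H1.shape[1])
    reg[-1, -1] = 0.0                                  # leave the intercept unpenalized
    return np.linalg.solve(H1.T @ H1 + reg, H1.T @ y)

def r_squared(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def fit_logistic_probe(H, labels, lr=0.1, steps=500):
    """Binary logistic-regression probe trained by full-batch gradient descent."""
    H1 = add_bias(H)
    w = np.zeros(H1.shape[1])
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(H1 @ w)))            # predicted probability of class 1
        w -= lr * H1.T @ (p - labels) / len(labels)    # gradient of the mean cross-entropy
    return w

def logistic_accuracy(w, H, labels):
    return np.mean((add_bias(H) @ w > 0) == labels)

rng = np.random.default_rng(0)
n, d = 600, 32
# Hypothetical per-layer activations: the property is planted linearly in layer_a only.
layers = {"layer_a": rng.normal(size=(n, d)), "layer_b": rng.normal(size=(n, d))}
direction = rng.normal(size=d)
target = layers["layer_a"] @ direction + 0.1 * rng.normal(size=n)   # continuous property
labels = (target > 0).astype(float)                                 # binary property
train, test = slice(0, 400), slice(400, n)

results = {}
for name, H in layers.items():
    w_lin = fit_linear_probe(H[train], target[train])
    r2 = r_squared(target[test], add_bias(H[test]) @ w_lin)
    w_log = fit_logistic_probe(H[train], labels[train])
    acc = logistic_accuracy(w_log, H[test], labels[test])
    results[name] = (r2, acc)
    print(f"{name}: held-out R^2 = {r2:.3f}, accuracy = {acc:.3f}")
```

<p>Only the layer with the planted signal yields high held-out scores. With real activations, the same comparison across layers gives a probing curve like those discussed in this post.</p>

<p>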
We use probing to track the flow of information through a model’s layers and to identify where specific knowledge is stored or transformed.</p>

<p>We used binary probes for a simple classification task between two types of equations: exponential, \(\frac{dx}{dt} = ax\), and hyperbolic, \(\frac{dx}{dt} = ax^2\). Layers indexed 0-3 are the encoder, while layers 4-15 are the decoder. The classification accuracy increases through the encoder and is essentially perfect throughout the decoder. This holds even though ODEFormer has some performance issues on the exponential samples, reconstructing a trajectory with \(R^2 \ge 0.5\) for only 40% of them.</p>

<p>We also trained a probe to classify one- and two-dimensional systems. The probe accuracy was fairly high, after an initial decrease in the first three decoder layers. However, we had expected near-perfect accuracy, since ODEFormer should easily detect the system dimension from the dimension of the input trajectory. It may be that, by the decoder layers, this information no longer needs to be linearly represented in the activations. Also surprisingly, ODEFormer itself occasionally gets the dimensionality of the predicted ODE system wrong, and it is unclear why.</p>

<p>We also explored 2D linear systems of the form
\[
\begin{aligned}
\frac{dx}{dt} &amp;= \alpha x + \beta y, \\
\frac{dy}{dt} &amp;= \gamma x + \delta y.
\end{aligned}
\]
One hypothesis is that ODEFormer uses the eigenvalues of the matrix \(A = \begin{pmatrix} \alpha &amp; \beta \\ \gamma &amp; \delta \end{pmatrix}\), as these determine the behaviour of the system. As a baseline for comparison, we also predicted the coefficients \(\alpha\), \(\beta\), \(\gamma\) and \(\delta\), since these must be represented in order to produce the right-hand side of the system. Probe performance is worse for the eigenvalues than for the coefficients, which is some evidence against the hypothesis that ODEFormer calculates eigenvalues in order to solve ODEs. We generated 10k samples with different combinations of the coefficients \(\alpha, \beta, \gamma, \delta \in [-2, 2]\).
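</p>

<p>For reference, here is how coefficient and eigenvalue probe targets could be generated for such a dataset. The sampling range follows the text. The random seed, the closed-form trace/determinant formula and the choice to split complex eigenvalues into real and imaginary parts are our assumptions about how the targets might be encoded, not a description of the original pipeline.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000

# Coefficients drawn uniformly from [-2, 2]; columns are alpha, beta, gamma, delta.
coeffs = rng.uniform(-2.0, 2.0, size=(n_samples, 4))
A = coeffs.reshape(n_samples, 2, 2)          # each matrix is [[alpha, beta], [gamma, delta]]

# Eigenvalues of a 2x2 matrix from its trace and determinant:
# lambda = (tr +/- sqrt(tr^2 - 4 det)) / 2, complex when the discriminant is negative.
tr = A[:, 0, 0] + A[:, 1, 1]
det = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]
disc = np.sqrt((tr**2 - 4.0 * det).astype(complex))
eig_plus = (tr + disc) / 2.0
eig_minus = (tr - disc) / 2.0

# Probe targets (hypothetical encoding): real and imaginary parts of both eigenvalues,
# versus the four coefficients themselves as the baseline.
eig_targets = np.stack([eig_plus.real, eig_plus.imag, eig_minus.real, eig_minus.imag], axis=1)
coef_targets = coeffs

frac_complex = np.mean(eig_plus.imag != 0)
print(f"fraction of samples with complex eigenvalues: {frac_complex:.2f}")
```

<p>A probe could then be fit once on <code>coef_targets</code> and once on <code>eig_targets</code> for each layer, restricted to the samples that ODEFormer solves well.</p>

<p>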
We then kept only the 2732 samples on which ODEFormer performed well (\(R^2 \ge 0.9\)), as we are particularly interested in ODEFormer’s behaviour when it succeeds.</p>

<p>In summary, the probing experiments give negative results for the hypothesis that the network searches through different expressions while evaluating their fit. We also got negative results for the network calculating the eigenvalues of the coefficient matrix. Finally, we got some positive results for the network calculating the location of the maximum derivative, which we explored further with the Attention Lens.</p>

<p>A sparse autoencoder (SAE) is an unsupervised neural network architecture designed to learn compressed representations of input data while enforcing sparsity in the hidden layer. To interpret ODEFormer, we trained SAEs on its internal activations so that the SAE encoder learned a small number of active neurons (features) that could reconstruct the original activations. This sparsity constraint encourages the network to learn disentangled, interpretable features rather than dense, overlapping representations. Once trained, these sparse features were analyzed to understand what kinds of patterns or concepts the original model encodes internally.</p>

<p>We succeeded in training sparse autoencoders to analyze the features corresponding to particular equation types. Our analysis of ODEFormer using sparse autoencoders revealed several interesting and unexpected findings about how the model processes differential equations:</p>

<p>We developed some key functions, shared in this library, that let us directly plot token charts, obtain intermediate tokens and logits, and obtain attentions (both value-weighted and the usual ones). We made several interesting discoveries using the Logit Lens on a sign-prediction example. We consider a one-dimensional decreasing exponential function and show the Logit Lens results in the figure below (Fig.
LALens1).</p>

<p>The figure shows how the predicted tokens and their confidence scores vary across beams and decoder layers. For this experiment, we used a simple one-dimensional decreasing exponential system, primarily because we had observed earlier that this system yields a simple 6-token output. We also opted for a low beam size (2) and temperature (0.1) for simplicity. A general observation is that in the initial layers, across all beams, the model still “predicts” the previous token, a consequence of the model’s autoregressive nature. However, this does not seem to hold when predicting a token that follows a constant. In the subsequent layers the probability is spread across several tokens, which is why the distribution looks more uniformly blue, and in the last few layers the model becomes confident in the right token. In the token_3 subplot, however, the prediction of the constant does not become very confident even in the last few layers, unlike in the other subplots. This is expected: these constants are of the order of \(10^{-4}\), and small changes in their magnitude do not significantly affect the overall expression.</p>

<p>Next, we zoom in on the sign prediction. For both decreasing and increasing exponential systems, decoder layer 6 predicts ‘+’; for the decreasing system, this is corrected to ‘-’ only in decoder layer 8. This bias towards ‘+’ is presumably because most numbers in the training data were positive. We also examined the attention patterns of the various heads. Some heads capture the direction from maximum to minimum magnitude (or vice versa) in the trajectory, which we believe is key to determining the sign. In Fig. LALens2, the rows corresponding to the value 1 on the y-axes show this minimum-to-maximum (or reverse) pattern. Row 1 represents the token ‘+’ or ‘-’.
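</p>

<p>The two library features mentioned above can be illustrated with a small NumPy sketch: intermediate logits (the Logit Lens) and value-weighted attention. This is not the library’s code. The function names, array shapes, the optional final-norm hook and the choice to renormalize value-weighted rows are all assumptions. The value weighting here multiplies each attention weight by the norm of the source position’s value vector, which is one common reading of value-weighted attention.</p>

```python
import numpy as np

# Hypothetical helpers illustrating the two ideas; not the project's library API.

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def logit_lens(hidden_states, W_U, final_norm=None):
    """Decode every layer's residual-stream vector at one position with the unembedding.

    hidden_states: (n_layers, d_model); W_U: (d_model, vocab_size).
    final_norm: optional callable applied first, if the model normalizes
    before its unembedding (an assumption about the architecture).
    Returns the top token id and its probability for each layer.
    """
    h = final_norm(hidden_states) if final_norm is not None else hidden_states
    probs = softmax(h @ W_U)
    top = probs.argmax(axis=-1)
    return top, probs[np.arange(len(top)), top]

def value_weighted_attention(attn, V, eps=1e-12):
    """Scale each attention weight A[i, j] by the norm of source position j's value vector.

    attn: (n_query, n_key) attention probabilities of one head.
    V: (n_key, d_head) value vectors of that head.
    Rows are renormalized to sum to 1 so they can be plotted like ordinary attention.
    """
    weighted = attn * np.linalg.norm(V, axis=-1)[None, :]
    return weighted / (weighted.sum(axis=-1, keepdims=True) + eps)

# Toy logit lens: 3 layers, a 4-dimensional residual stream and a 4-token vocabulary.
W_U = np.eye(4)
hidden = np.array([
    [1.0, 0.0, 0.0, 0.0],     # early layer: still favours token 0
    [0.2, 0.3, 0.25, 0.25],   # middle layer: probability spread across tokens
    [0.0, 0.0, 0.0, 3.0],     # last layer: confident in token 3
])
top, p = logit_lens(hidden, W_U)
print("top token per layer:", top, "probabilities:", p.round(3))

# Toy attention sink: position 0 gets most attention but has a zero value vector.
attn = np.array([[0.8, 0.1, 0.1]])
V = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
vw = value_weighted_attention(attn, V)
print("value-weighted attention:", vw.round(3))
```

<p>In the toy attention row, position 0 receives 80% of the attention but has a zero value vector, so its value-weighted share drops to zero. This matches the signature of an attention sink described in this post.</p>

<p>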
The three patterns in Fig. LALens2 come from attention heads in the layers where ‘+’ or ‘-’ is predicted for the first time. The single dominant column of attention is discussed later. We also looked at the logits of the ‘+’ and ‘-’ tokens and their evolution through the layers. There is no immediate difference between the two plots. However, ‘+’ is predicted from layer 6 onwards, and logit 389 also seems to be more activated from layer 6 onwards (labelled 7 in the plot, because numbering begins from 1). We therefore conclude that most of the information about the sign is carried by logit 389 alone.</p>

<p>Now let’s discuss the attention patterns in more detail. Broadly, we found clear evidence of three phenomena in ODEFormer. The first is null attention, or attention sinks. The second is a point that receives most of the attention but does not seem special at first glance. The third is attention heads that capture various mathematical qualities of the input trajectory, which help the model predict the symbolic expression.</p>

<p>During our experiments, we saw (often with larger beam sizes) that once one beam has produced an “&lt;EOS&gt;” token while other beams have not, that beam produces an “add” token or another “&lt;EOS&gt;” token next. If the other beams still have not produced “&lt;EOS&gt;”, the finished beam keeps producing the “N1010” constant token as a kind of “default” token. For example, consider the three plots in Fig. LALens4, which show the prediction of three consecutive tokens. In the first subplot, beams 1-4, 7 and 9 output the “&lt;EOS&gt;” token, but the other beams did not. As a result, when the other beams carry on with the next token prediction, these 6 beams output either “&lt;EOS&gt;” or “add”. If the other beams still have not finished, these 6 beams keep producing the “N1010” token as a “default” token until all other beams produce “&lt;EOS&gt;”. This completes token generation, after which post-processing takes place.</p>

<p>We observed that attention sinks are present in the attention plots. These are input trajectory points that receive more attention than other points but transfer little information, as seen in the value-weighted attention plots. What is value-weighted attention? Taken from “A Mathematical Framework for Transformer Circuits” (Elhage et al.), it scales the attention in the plots by how much information is transferred, as measured from the value vectors. It therefore shows how much value is actually moved by attending to each input point. Consider a point that receives most of the attention from all or most tokens (or from the points themselves, in the encoder’s self-attention) in the usual plots but little in the value-weighted ones. Such a point contributes little to the prediction despite being heavily attended to, and we call it an attention sink. Even more interesting is a different class of points, which we call MAT (mostly-attended-to) points. These points are attended to by all (or most) tokens in the decoder’s value-weighted cross-attention plots. They light up almost as strongly in the value-weighted plots as in the normal ones. This shows that they do transfer information that the model (specifically the decoder layers) needs to predict the symbolic expression correctly. For example, consider the attention plot in Fig. LALens5: in the top subplot, a few input points (columns) are clearly more lit up than the others. These points also transfer information, so the attention on them is not wasted.</p>

<p>Lighting up is fun, but you may now wonder what the defining mathematical quality of the MAT point (column) is. We plotted the norm of every point of the harmonic sine-cosine system, and the plot in Fig. LALens6 shows that the MAT point has a very low norm.</p>

<p>Let’s also take an example to show the difference between value-weighted and normal attention. In Fig. LALens7, other points clearly transfer more information, but the MAT point is still clearly visible. The attention plots in this figure correspond to the system in the previous figure.</p>

<p>We initially observed that a few points, or most often a single point, were attended to by all input points or tokens, but the position of this point appeared random. This was because ODEFormer randomly selects input points when the original input is very long. Since we were experimenting with very short trajectories, we did not need this randomisation, and after turning it off we observed that the MAT point is constant. What do we mean? Take x input points from various systems of different dimensionalities and feed them to ODEFormer to get the predicted symbolic expression: in all cases, the MAT point is the same. Change the number of input points to y, and the MAT status shifts. In our experiments this shift is also predictable: roughly the point at 60% of the input trajectory gets the MAT status. With 25 input points, for example, point 15 becomes the MAT point. One might suspect an internal 60:40 split of the trajectory, with the MAT point acting as the attention sink of the latter 40%, since sinks are usually the first point. This is a reasonable idea, but in some cases we observed no abrupt change in the attention pattern before and after the MAT point. Instead, attention gradually increases as we approach the MAT point from the left and then decreases in the same gradual manner past it. If there were some kind of internal training/testing “split”, we would not expect this gradual build-up of attention around the MAT point. Moreover, this pattern is visible only in the value-weighted plots. This suggests that although the points around the MAT point are not heavily attended to, they do transfer useful information.</p>

<p>We also ran a few tests to gain more insight into the nature of the MAT point. Note that the attention in the following plots is not value-weighted, so that the MAT point stands out clearly. Before the experiment, we believed that the MAT status was due to the value of the point in the trajectory. Next, we modified the trajectory to have several different values at the same time step (i.e. made the slope infinite over a portion of the input trajectory). The model then attended to all of these points more than to the others. As seen in the second subplot of Fig. LALens8, the model tries to give the MAT status to all the points at that one time step. It does not succeed, however, and instead the attention the points receive decays gradually as we move right. In the third experiment and subplot, we changed the MAT point’s value to something entirely different. In this case, some other point gets the MAT status. But wait a second! It is what we call the second MAT point that gets the MAT status once the original MAT point is shifted sufficiently. What is the second MAT point, you ask? Looking back at the first subplot, the MAT point receives more attention than the others, but two more points also receive somewhat more attention than the rest. We call these the MAT candidates. As we saw, changing the MAT point’s value sufficiently results in one of the candidates getting the MAT status instead. A natural question is: by how much must the MAT point be shifted? Let’s look at Fig. LALens9. Changing the MAT point’s value by up to approximately 0.05 in absolute value retains most of the usual cross-attention and does not shift the MAT status to a candidate. Changing it by more than 0.05, however, causes a significant dip in the averaged cross-attention to the MAT point, and the MAT status passes to another candidate point.</p>

<p>One last observation: in some value-weighted attention heads, across several systems, the MAT point receives less attention than the other points. These heads mostly capture some mathematical aspect of the input trajectory. Even where the MAT point should logically receive some attention (for example, because it is a maximum or minimum), it does not.</p>

<p>Low-level and high-level features develop in the attention heads of vision transformers, and we observe a similar phenomenon in ODEFormer: it too has specialized attention heads that capture mathematical aspects of the n-dimensional input system. We look only at the self-attention of the encoder blocks, since it has input points on both axes. Our aim is to see what information is “encoded” into the trajectory in the encoder layers, before the decoder layers engage in cross-attention while predicting tokens. We treat ODEFormer as a multi-modal transformer, because it takes input in the trajectory space and produces output in a well-defined token space.</p>

<p>Now let’s take a closer look at some features in a few systems. Some context first: in Fig. LALens10, Sigmoid_a_b denotes an input trajectory following a sigmoid with parameters a and b, which set the position and steepness of the curve. When changing the values of a and b, we noticed that one head, encoder layer 2 head 7, seemed to pick up on differentiation, i.e. on the portion of the curve where the slope is non-zero. We also noticed that encoder layer 0 head 8 seems, at first glance, to trace the input sigmoid trajectory itself. When the sigmoid instead goes from 0 to 1 (rather than from 1 to 0, as in the figure), the tracing head does not trace the trajectory perfectly. Its pattern does, however, shift as the transition region shifts.</p>

<p>In Sigmoid_3_20, you might notice that the prediction does not look much like a sigmoid curve. Indeed, we found that ODEFormer can predict the sigmoid accurately only up to a point. Beyond a certain steepness, it derails quickly and badly, suggesting that the transformer cannot handle fast, non-periodic transitions in the input trajectory. The model still captured significant sigmoid-like features for Sigmoid_3_20, but it failed completely for the much steeper Sigmoid_8_25. In that case, there are no meaningful patterns in the attention heads, just noise.</p>

<p>While differentiation is a high-level concept, tracing is a relatively low-level, low-effort one. We hypothesize that this is why a head in layer 0 traces the trajectory, while differentiation is captured by a head in layer 2. In support of this, we also found heads in layer 0 that capture the maximum and minimum of a trajectory. These findings may not hold for all systems, but in our observations they hold for some simple systems.</p>

<p>While observing the cross-attention after feeding ODEFormer the trajectory from Fig. LALens6, we found a few more attention heads that seemed to capture the minima and maxima. In Fig. LALens11, the head on the left, decoder layer 5 head 12, captures the minima and maxima but shows no separation by dimension, even though the input trajectory belongs to a 2D system. The head on the right, decoder layer 6 head 4, captures the minima and maxima with the dimensions separated. The pattern in the left head is rare and appears only before the pattern in the right head emerges, around the middle layers. The pattern with dimension separation is more common in these kinds of systems.</p>

<p>We plotted the OV matrix, obtained by multiplying the output weight matrix by the value weight matrix, for each encoder layer; the results are in Fig. LALens12. The most striking feature in all four encoder layers is a leading diagonal that is significantly more negative than the other values. Moreover, the diagonal becomes increasingly negative from layer 0 to the final encoder layer. This coincides with the formation of the MAT point, which begins forming in the second encoder layer and becomes more and more pronounced towards the final encoder layers. We think this negative diagonal may be some form of self-inhibition (something like decreasing its own effect). We have yet to determine what role, if any, this self-inhibiting behaviour plays in the formation of the MAT point. This strikingly negative leading diagonal is present in the OV plots of the encoder layers across different systems.</p>]]></content><author><name></name></author></entry><entry><title type="html">AI Risks: Misuse, Accidents, and Rogue Systems</title><link href="http://localhost:4000/blog/2024/ai-risks/" rel="alternate" type="text/html" title="AI Risks: Misuse, Accidents, and Rogue Systems" /><published>2024-02-18T00:00:00+06:00</published><updated>2024-02-18T00:00:00+06:00</updated><id>http://localhost:4000/blog/2024/ai-risks</id><content type="html" xml:base="http://localhost:4000/blog/2024/ai-risks/"><![CDATA[<div class="caption">
    <q><i>I wrote this blog as an exercise for the AI Safety Fundamentals Governance Course.</i></q>
</div>

<h2 id="do-you-think-risks-arising-from-misuse-accidents-or-rogue-agentic-ai-systems-are-more-likely-to-cause-harm">Do you think risks arising from misuse, accidents, or rogue, agentic AI systems are more likely to cause harm?</h2>

<p>At the current pace of AI research and deployment, I believe all three categories of risk (misuse, accidents, and rogue AI systems) pose significant threats that could lead to substantial harm. While it’s difficult to say definitively which type of risk is “more likely” to cause harm, each has the potential for catastrophic consequences.</p>

<p>Accidents are a major concern as AI systems become increasingly complex and are deployed in high-stakes real-world applications. Due to the “black box” nature of many modern AI models, it can be very challenging to fully understand and predict their behavior, especially in edge cases or in novel environments. Bugs, data biases, and unintended interactions could lead to serious accidents that harm individuals or communities.</p>

<p>Misuse is also a grave risk if powerful AI capabilities fall into the wrong hands. Malicious actors could leverage AI for surveillance, disinformation, cyberattacks, autonomous weapons, and other nefarious purposes. The low barriers to entry, combined with the scale and automation that AI enables, make this a concerning threat.</p>

<p>Perhaps most worrying are scenarios involving the development of highly agentic, self-improving AI systems that could spiral out of human control. While still speculative, the potential for such “rogue AI” to recursively enhance its own capabilities and pursue goals misaligned with human values and interests is a risk that many experts take very seriously.</p>

<h2 id="does-this-answer-change-when-limiting-your-time-horizon-to-5-years-15-years-and-30-years">Does this answer change when limiting your time horizon to 5 years, 15 years and 30 years?</h2>

<p>In the near term (5 years), I believe accidents are the most immediate and pressing concern. As AI systems become more complex and are rapidly deployed in high-stakes applications, the risk of accidents causing harm to individuals and communities is quite high. Misuse is also a concern, as bad actors could leverage existing AI capabilities, but the threat of rogue AI systems is likely lower in the next 5 years.</p>

<p>Looking out 15 years, the risk profile changes. Advances in AI capabilities and autonomy increase the potential both for misuse and for the development of agentic AI systems. While accidents will still be a risk, the threat of malicious use and of difficult-to-control AI systems may surpass it. Careful research into AI alignment (ensuring that AI systems remain robustly aligned with human values and interests) becomes critical.</p>

<p>Over a 30-year horizon, the potential development of transformative AI systems that far exceed human capabilities introduces existential risk: if such systems are not reliably aligned with human values and interests, they could threaten humanity as a whole. Accidents and misuse would still be concerns, but the emergence of uncontrolled, superintelligent AI would become the predominant risk.</p>

<h2 id="minimizing-research-gaps-of-general-ai-research-and-alignment-research">Minimizing the Gap Between General AI Research and Alignment Research</h2>

<p>A critical priority must be to close the widening gap between advancements in general AI capabilities and progress in AI alignment research. As powerful AI systems become increasingly sophisticated and autonomous, ensuring their goals and behaviors remain robustly aligned with human values grows ever more crucial. Yet, current trends suggest alignment research is struggling to keep pace with the rapid progress in general AI. This disparity is deeply concerning, as it increases the risk that powerful, unaligned AI systems could emerge before we have the necessary technical and conceptual tools to ensure they remain under meaningful human control. To mitigate this risk, a concerted effort is needed to elevate the importance of alignment research and dramatically accelerate progress in this domain.</p>]]></content><author><name></name></author><category term="ai-safety" /><summary type="html"><![CDATA[This article explores the evolving threats of AI misuse, accidental harm, and the rise of autonomous systems, emphasizing the critical need for alignment research to ensure AI's safety and ethical integration into society.]]></summary></entry></feed>