Back to blog

GPT-2 miniを自作して学ぶTransformerの仕組み - 壊して観測して理解する

nn.Transformerを使わずにAttention、MLP、LayerNorm、Residualを自前で実装し、各コンポーネントを無効化する実験を通じてTransformerがなぜ動くのかを理解する

なぜ GPT-2 mini を自作したのか

LLM(大規模言語モデル)が話題だが、「なぜ Transformer は学習できるのか」「なぜあの構造なのか」を腹落ちさせたかった。

論文を読むだけでは理解が浅い。実際に手を動かして、壊して観測して理解するアプローチを取ることにした。

実装方針

  • nn.Transformer は使わない
  • Attention / MLP / LayerNorm / Residual を自前で組む
  • 実験スイッチで各コンポーネントを無効化できるようにする
  • Mac(MPS)で動作確認

コード全体は GitHub リポジトリ で公開している。

Decoder-only Transformer の全体像

GPT 系のモデルは以下の流れで動作する。

入力トークン

Token Embedding + Positional Embedding

N回 Transformer Block を通す

LayerNorm

Linear(語彙サイズ)→ logits

Cross Entropy Loss

各ステップを順番に見ていく。

1. Embedding(埋め込み)

トークン(文字や単語)を数値ベクトルに変換する。

self.tok_emb = nn.Embedding(vocab_size, n_embd)  # トークン埋め込み
self.pos_emb = nn.Embedding(block_size, n_embd)  # 位置埋め込み

なぜ位置埋め込みが必要か

Attention は「どの位置のトークンか」という情報を持たない。RNN と違い、入力を並列処理するため、位置情報を明示的に与える必要がある。

2. Causal Self-Attention

Transformer の心臓部。「どのトークンに注目するか」を学習する。

# Q, K, V を線形変換で作成
qkv = self.c_attn(x)
q, k, v = qkv.split(n_embd, dim=2)

# スケールド内積
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))

# Causal mask(未来を見ない)
att = att.masked_fill(causal_mask == 0, float("-inf"))
att = F.softmax(att, dim=-1)

# 重み付き和
y = att @ v

Causal mask とは

GPT は次のトークンを予測するタスクなので、未来のトークンを見てはいけない。下三角行列でマスクすることで、各位置は自分より前のトークンだけを参照できる。

位置 0: [1, 0, 0, 0]  → 自分だけ見える
位置 1: [1, 1, 0, 0]  → 0と1が見える
位置 2: [1, 1, 1, 0]  → 0,1,2が見える
位置 3: [1, 1, 1, 1]  → すべて見える

3. MLP(Feed-Forward Network)

位置ごとの非線形変換。Attention で「どこを見るか」を決めた後、「何を抽出するか」を学習する。

class MLP(nn.Module):
    def __init__(self, config):
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)   # 拡張
        self.gelu = nn.GELU()                        # 活性化
        self.c_proj = nn.Linear(4 * n_embd, n_embd) # 縮小

隠れ層を 4 倍に拡張するのは GPT-2 の設計に従った。GELU は ReLU より滑らかな活性化関数。

4. Residual 接続

入力を出力に加算する。

x = x + self.attn(self.ln1(x))  # Attention + Residual
x = x + self.mlp(self.ln2(x))   # MLP + Residual

なぜ必要か

深いネットワークでは勾配が消失しやすい。Residual 接続は「情報の高速道路」として機能し、勾配が直接流れる経路を提供する。

5. LayerNorm の位置(Pre-LN vs Post-LN)

Pre-LN(現代の標準)

x → LayerNorm → Sublayer → +Residual

Post-LN(オリジナル Transformer)

x → Sublayer → +Residual → LayerNorm

Pre-LN は勾配が安定しやすく、学習率に対してロバスト。

実験:壊して観測する

各コンポーネントを無効化して、何が起きるか観測した。

実験 1: Residual OFF

python train_gpt2_mini.py --disable_residual 1 --steps 30

結果

設定30step 後 Loss
Baseline2.3
Residual OFF4.15

考察

Loss がほぼ下がらない。情報が深い層に流れず、学習が成立しない。

実験 2: Post-LN

python train_gpt2_mini.py --ln_style post --steps 30

結果

設定30step 後 Loss
Pre-LN2.3
Post-LN3.64

考察

学習は進むが遅い。初期の勾配分布が不安定で、収束に時間がかかる。

実験 3: Attention OFF(MLP-only)

python train_gpt2_mini.py --disable_attention 1 --steps 30

結果

設定30step 後 Loss
Baseline2.3
Attention OFF2.78

考察

Loss は下がるが、生成テキストの文脈依存が弱い。MLP は位置ごとの変換しかできないため、離れたトークン間の依存関係を捉えられない。

学んだこと

1. Residual は必須

深いネットワークで学習を成立させるには、勾配が直接流れる経路が必要。Residual なしでは 4 層程度でも学習が壊れる。

2. LayerNorm の位置は重要

Pre-LN は Post-LN より安定。現代の LLM が Pre-LN を採用している理由が体感できた。

3. Attention の役割

「位置をまたぐ情報の統合」が Attention の本質。MLP だけでは文脈を捉えられない。逆に言えば、短い依存関係なら MLP でも学習できる。

4. 壊して学ぶ価値

論文を読むだけでは「なぜその設計なのか」がわからない。実際に壊してみることで、各コンポーネントの必要性が腹落ちする。

次のステップ

  • Attention entropy のログ化(どれだけ鋭く見ているか)
  • 層別の grad_norm 可視化
  • BPE tokenizer との比較

コードは GitHub で公開している。実験を通じて Transformer の理解を深めたい方はぜひ試してほしい。