Tiny GPT

Implementing a Tiny GPT from Scratch.
Author

Madhav Kanda

Published

December 10, 2025

This notebook implements a byte-level language model:

  • Byte-level tokenizer (0..255) with encode/decode demo
  • Dataset creation, preview, and (x, y) next-token shift check
  • Tiny GPT model (embeddings → Transformer blocks → LM head)
  • Training loop (AdamW, grad clipping, optional AMP) with periodic eval
  • Sampling with temperature, top-k/top-p
  • Validation loss and parameter count
# 0) Setup: imports, device, seed, matplotlib (no project imports)
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Reproducibility
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Matplotlib defaults
plt.rcParams["figure.figsize"] = (6, 4)
plt.rcParams["figure.dpi"] = 120
Device: cuda

Byte-level tokenizer

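We operate directly on UTF-8 bytes (0..255), so there is no vocabulary to build or manage. Encoding and then decoding ordinary text is lossless: decode(encode(text)) == text. The reverse direction can lose information: decode uses errors="ignore", so bytes that do not form valid UTF-8 (which the model can produce when sampling) are silently dropped.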

Tokenizer implementation

class ByteTokenizer:
    """Ultra-simple byte-level tokenizer.
    - encode(str) -> LongTensor [N]
    - decode(Tensor[int]) -> str
    - vocab_size = 256
    """

    def encode(self, s: str) -> torch.Tensor:
        return torch.tensor(list(s.encode("utf-8")), dtype=torch.long)

    def decode(self, ids) -> str:
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        return bytes(ids).decode("utf-8", errors="ignore")

    @property
    def vocab_size(self) -> int:
        return 256

Demo: encode/decode

Round-trip correctness and vocabulary size.

# Demo: encode/decode round-trip

tok = ByteTokenizer()
text = "Hello, world!"
ids = tok.encode(text)
print("text -> ids:", ids.shape, ids[:16])
print("ids -> text:", tok.decode(ids))
print("vocab_size:", tok.vocab_size)
text -> ids: torch.Size([13]) tensor([ 72, 101, 108, 108, 111,  44,  32, 119, 111, 114, 108, 100,  33])
ids -> text: Hello, world!
vocab_size: 256
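The demo string is pure ASCII, so each character maps to exactly one byte. The following standard-library sketch (no torch; `encode_bytes` and `decode_bytes` are hypothetical stand-ins for `ByteTokenizer.encode`/`decode` that work on plain lists) shows what changes for non-ASCII text, and why `errors="ignore"` matters when a byte sequence is cut off mid-character:

```python
def encode_bytes(s: str) -> list[int]:
    # Same idea as ByteTokenizer.encode, but returning a plain list of ints
    return list(s.encode("utf-8"))


def decode_bytes(ids: list[int]) -> str:
    # Invalid or truncated UTF-8 sequences are silently dropped
    return bytes(ids).decode("utf-8", errors="ignore")


text = "héllo"
ids = encode_bytes(text)
print(len(text), len(ids))          # 5 6: "é" takes two bytes
print(ids[1:3])                     # [195, 169]
print(decode_bytes(ids) == text)    # True: the full round trip is lossless
print(repr(decode_bytes(ids[:2])))  # 'h': the lone lead byte of "é" is dropped
```

One consequence for generation: a sampled sequence can end in the middle of a multi-byte character, and that fragment disappears from the decoded text instead of raising an error.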

Dataset and (x, y) shift

We’ll build batches from an in-memory corpus string (no external files):

  • Inputs x: bytes [t, ..., t+block_size-1]
  • Labels y: bytes [t+1, ..., t+block_size], i.e. the same window shifted one position ahead, so y[i] is the token that follows x[i]
  • A random start index t for each sample

We also show a small preview and validate the shift property.

class ByteDataset:
    """Builds (x,y) pairs from a single long tensor of token ids.
    Splits into train/val by fraction and samples random contiguous blocks.
    """

    def __init__(self, data: torch.Tensor, block_size: int = 256, split: float = 0.9):
        n = int(len(data) * split)
        self.train = data[:n].clone()
        self.val = data[n:].clone()
        self.block_size = block_size

    def get_batch(self, which: str, batch_size: int, device: torch.device):
        buf = self.train if which == "train" else self.val
        assert len(buf) > self.block_size + 1, "corpus too small for given block_size"
        ix = torch.randint(0, len(buf) - self.block_size - 1, (batch_size,))
        x = torch.stack([buf[i : i + self.block_size] for i in ix])
        y = torch.stack([buf[i + 1 : i + 1 + self.block_size] for i in ix])
        return x.to(device), y.to(device)
# In-memory tiny corpus (repeat to ensure enough length)
corpus = (
    "Hello there! This is a tiny demo corpus for training a byte-level GPT.\n"
    "It is intentionally small so the demo runs fast on CPU.\n"
    "Have fun experimenting with temperature, top-k, and top-p!\n"
) * 50

# Tokenize corpus to bytes
corpus_ids = tok.encode(corpus)
print("corpus length:", len(corpus_ids))
corpus length: 9300
# Build dataset and preview a batch
ds = ByteDataset(corpus_ids, block_size=64, split=0.9)
xb, yb = ds.get_batch("train", batch_size=4, device=device)
print("x,y shapes:", xb.shape, yb.shape)
print("shift check:", torch.equal(xb[:, 1:], yb[:, :-1]))
x,y shapes: torch.Size([4, 64]) torch.Size([4, 64])
shift check: True
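The same windowing can be sketched in pure Python on a list of ints instead of a tensor (the function `get_batch` and its `rng` argument here are illustrative, not the notebook's API). One small difference: `torch.randint`'s upper bound is exclusive, so `ByteDataset.get_batch` never picks the very last valid start index, while `random.Random.randint` below is inclusive on both ends and can.

```python
import random


def get_batch(data, block_size, batch_size, rng):
    """Sample random (x, y) windows from a list; y is x shifted one position ahead."""
    # Largest start that still leaves room for the shifted target window.
    max_start = len(data) - block_size - 1
    starts = [rng.randint(0, max_start) for _ in range(batch_size)]  # inclusive bounds
    xs = [data[s : s + block_size] for s in starts]
    ys = [data[s + 1 : s + 1 + block_size] for s in starts]
    return xs, ys


data = list(b"hello world")
xs, ys = get_batch(data, block_size=4, batch_size=3, rng=random.Random(0))
for x, y in zip(xs, ys):
    assert x[1:] == y[:-1]  # the same shift property checked above
    print(bytes(x), "->", bytes(y))
```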

Tiny GPT model

Architecture overview:

  • Token embeddings plus learned position embeddings
  • N pre-norm Transformer blocks, each with causal self-attention and an MLP, both wrapped in residual connections
  • Final layer norm followed by a linear head

Defining the CausalSelfAttention, FeedForward, and Block modules

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d_head = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):  # (B,T,C)
        B, T, C = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_head, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, n_embd: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, mult * n_embd),
            nn.GELU(),
            nn.Linear(mult * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, mult=4, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
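`F.scaled_dot_product_attention(..., is_causal=True)` computes softmax(Q K^T / sqrt(d)) V with future positions masked out, for every batch element and head at once. A single-head, pure-Python sketch of that computation (the name `causal_attention` is illustrative, and dropout is omitted) makes the causal property easy to check: the output at position t depends only on positions 0..t.

```python
import math


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k: T vectors of length d; v: T vectors of length d_v (lists of floats).
    Position t attends only to positions 0..t, as with is_causal=True.
    """
    T, d = len(q), len(q[0])
    out = []
    for t in range(T):
        # Scaled dot products against current and past keys only (the future is masked).
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d) for s in range(t + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        w = [math.exp(x - m) for x in scores]
        z = sum(w)
        w = [x / z for x in w]
        # Weighted average of the visible value vectors.
        out.append([sum(w[s] * v[s][j] for s in range(t + 1)) for j in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
before = causal_attention(q, k, v)
v[2] = [100.0, -100.0]  # change only the last (future) value vector
after = causal_attention(q, k, v)
print(before[:2] == after[:2])  # True: earlier positions never see position 2
print(before[0] == [1.0, 2.0])  # True: position 0 can only attend to itself
```

In the PyTorch module, the `view(B, T, 3, n_head, d_head)` and transposes arrange the tensors so this same computation runs independently per head, after which the heads are concatenated and projected by `proj`.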

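Sampling utility

A helper that filters next-token logits before sampling: top-k keeps only the k highest-scoring tokens, and top-p (nucleus) keeps only the most probable tokens covering roughly the top-p share of probability mass. All other logits are set to -inf, so they get zero probability after the softmax in generate.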

# Sampling utilities (inline)
def top_k_top_p_filtering(
    logits: torch.Tensor, top_k: int | None = None, top_p: float | None = None
):
    B, V = logits.shape
    filtered = logits.clone()
    if top_k is not None and 0 < top_k < V:
        topk_vals, _ = torch.topk(filtered, top_k, dim=-1)
        kth = topk_vals[:, -1].unsqueeze(-1)
        filtered[filtered < kth] = float("-inf")
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(filtered, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumsum = torch.cumsum(probs, dim=-1)
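        # Drop a token only when the mass before it already exceeds top_p, so the
        # token that crosses the threshold is kept (standard nucleus definition).
        mask = (cumsum - probs) > top_p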
        mask[..., 0] = False
        sorted_logits[mask] = float("-inf")
        filtered = torch.full_like(filtered, float("-inf"))
        filtered.scatter_(1, sorted_idx, sorted_logits)
    return filtered
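To see what the two filters keep, here is a pure-Python sketch on a four-token distribution (`filter_logits` and `survivors` are hypothetical helpers, not part of the notebook). It follows the common nucleus-sampling definition: sort tokens by probability and keep the smallest prefix whose cumulative probability reaches top_p.

```python
import math


def filter_logits(logits, top_k=None, top_p=None):
    """Return a copy of `logits` with filtered-out entries set to -inf."""
    out = list(logits)
    n = len(out)
    if top_k is not None and 0 < top_k < n:
        kth = sorted(out, reverse=True)[top_k - 1]  # k-th largest logit
        out = [x if x >= kth else -math.inf for x in out]
    if top_p is not None and 0.0 < top_p < 1.0:
        order = sorted(range(n), key=lambda i: out[i], reverse=True)
        m = max(out)
        exps = [math.exp(out[i] - m) for i in order]  # exp(-inf) == 0.0
        total = sum(exps)
        keep, cum = set(), 0.0
        for i, e in zip(order, exps):
            if cum >= top_p:  # mass already kept reaches top_p: stop
                break
            keep.add(i)
            cum += e / total
        out = [x if i in keep else -math.inf for i, x in enumerate(out)]
    return out


def survivors(logits):
    # Indices of tokens that can still be sampled
    return [i for i, x in enumerate(logits) if x != -math.inf]


logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
print(survivors(filter_logits(logits, top_k=3)))    # [0, 1, 2]
print(survivors(filter_logits(logits, top_p=0.7)))  # [0, 1]: 0.5 < 0.7, then 0.5 + 0.3 >= 0.7
```

The most probable token always survives, so the softmax in generate never sees an all -inf row.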

Defining the GPT model

class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int = 4,
        n_head: int = 4,
        n_embd: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 200,
        temperature: float = 1.0,
        top_k: int | None = 50,
        top_p: float | None = None,
    ):
        self.eval()
        if idx.size(1) == 0:
            # Empty prompt: start every sequence from a newline byte (10 == ord("\n"))
            idx = torch.full((idx.size(0), 1), 10, dtype=torch.long, device=idx.device)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-6)
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx


# Instantiate and quick forward
model = GPT(
    vocab_size=256, block_size=64, n_layer=2, n_head=2, n_embd=128, dropout=0.1
).to(device)
with torch.no_grad():
    logits, loss = model(xb, yb)
print("logits:", logits.shape, "loss:", float(loss))
logits: torch.Size([4, 64, 256]) loss: 5.572201251983643
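The initial loss of about 5.57 is a useful sanity check. An untrained model with small random weights produces nearly equal logits for all 256 bytes, and the cross-entropy of a uniform prediction is -ln(1/256) = ln 256 ≈ 5.545. A pure-Python check (the `cross_entropy` helper is illustrative):

```python
import math


def cross_entropy(logits, target):
    """Negative log-likelihood of `target` under softmax(logits), via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


uniform = [0.0] * 256  # equal scores for every byte
print(cross_entropy(uniform, target=72))  # about 5.5452, the same for any target
print(math.log(256))                      # about 5.5452
```

A starting loss far above ln 256 would suggest overly large initial weights; the slightly higher 5.57 observed here is consistent with logits that are close to, but not exactly, uniform.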

Training loop (demo)

We train with AdamW, gradient-norm clipping (max norm 1.0), and automatic mixed precision (AMP), which is enabled only when running on CUDA. Every 50 steps we estimate the validation loss over 25 batches and print a short sample.

def estimate_val_loss(
    model: GPT, ds: ByteDataset, iters: int = 50, batch_size: int = 32
) -> float:
    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(iters):
            xb, yb = ds.get_batch("val", batch_size, device)
            _, loss = model(xb, yb)
            losses.append(loss.item())
    model.train()
    return float(sum(losses) / len(losses))


model.train()
opt = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)
use_amp = device.type == "cuda"  # mixed precision only on GPU
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

steps = 300  # small for demo
for step in range(1, steps + 1):
    xb, yb = ds.get_batch("train", batch_size=32, device=device)
    with torch.autocast(device_type=device.type, enabled=use_amp):
        _, loss = model(xb, yb)
    opt.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    # Unscale first so clipping sees the true gradient norm (no-op when AMP is off)
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(opt)
    scaler.update()

    if step % 50 == 0:
        val = estimate_val_loss(model, ds, iters=25, batch_size=32)
        print(f"step {step:4d} | train {loss.item():.4f} | val {val:.4f}")
        # quick sample from a newline prompt
        prompt_ids = tok.encode("\n").unsqueeze(0).to(device)
        out = model.generate(
            prompt_ids, max_new_tokens=120, temperature=1.0, top_k=50, top_p=None
        )
        print("sample:\n" + tok.decode(out[0].cpu())[-200:])
        model.train()  # generate() switches to eval mode; re-enable dropout
step   50 | train 3.2102 | val 3.1415
sample:

vunPtetpTlkav [i ytmux!ve po
aP g i!sh g s rfa ayc! ng--erle
In,oipe ePm(ot g.lxrvo -Hpe ynxiinoP d xo t s vuxpe.TPr
step  100 | train 2.0894 | val 2.0733
sample:

heHe io.
therp-louo edefmisubg ent wnt ifa y tCee s, a s.
H or
Itin
HPUoplU. t -ll Gin 
e ere blfeoiir fs is t te!p
step  150 | train 1.3808 | val 1.3731
sample:

Heby elhelo inde sl so e CPU.nls nthemat ders timpemounmshensaveU fant funt eo dhon orao fokr ire trpenisyt, fhG.
Hiis w
step  200 | train 0.9038 | val 0.8888
sample:

 indemo tonPU.oplns :mavel fus h ton for eraihwe  byntuntemxpereryty GPThint-le, d des t usorap, rasrop-ls i t tornwioe
step  250 | train 0.5369 | val 0.5361
sample:

Uve fexpinthy temeruns imendeomh in tis deny furo s te, mexpemors funts trvenin]ratintsly teon alloally o colly .
Ha$(h
step  300 | train 0.3238 | val 0.3228
sample:

avel GPT.
He it inteundemois o mksl nallnthe CPU.
Hal inse io fun-lPT.
Have y s e funtfunste emaxpenstrentinontionU.
It 
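Gradient clipping interacts with AMP loss scaling: `GradScaler` multiplies the loss, and therefore every gradient, by a scale factor, so clipping should happen after `scaler.unscale_(opt)`, as in PyTorch's AMP examples. A pure-Python sketch with a hypothetical scale of 1024 (GradScaler chooses and adapts its own value) shows what goes wrong otherwise; `clip_by_global_norm` is an illustrative helper mirroring the rule `clip_grad_norm_` applies, up to a small epsilon.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale gradients so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads)
    return [g * max_norm / total for g in grads]


grads = [3.0, 4.0]  # true gradient, norm 5.0
scale = 1024.0      # hypothetical loss-scale factor

# Correct order: unscale, then clip -> final norm 1.0
good = clip_by_global_norm([g * scale / scale for g in grads], 1.0)

# Wrong order: clip the scaled gradients, then unscale -> final norm 1/1024
bad = [g / scale for g in clip_by_global_norm([g * scale for g in grads], 1.0)]

print(good)  # [0.6, 0.8]
print(math.hypot(*bad))
```

In the wrong order, the effective clipping threshold shrinks by the scale factor, so every update is clipped far more aggressively than the `max_norm=1.0` setting suggests.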

Manual sampling cell

Provide a prompt and generate a continuation; temperature, top_k, and top_p can all be adjusted.

model.eval()
# Try different prompts and decoding controls
prompt = "Once upon a time "
start_ids = tok.encode(prompt).unsqueeze(0).to(device)

with torch.no_grad():
    out = model.generate(
        start_ids,
        max_new_tokens=200,
        temperature=0.9,
        top_k=50,
        top_p=None,
    )

print(tok.decode(out[0].cpu())[-300:])
Once upon a time funers foastrng forat ainwith tempe-p-lls teutorenand ths ere, re, and torp-k, td and top!o top!
Ithehe! t rpp
He!thire Thereis tinihie is a a tiny te-llos y demo g tra forentunT.
It ininina wio the
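`generate` divides the last-step logits by `temperature` before the softmax: values below 1 sharpen the distribution toward the top-scoring byte, and values above 1 flatten it. A pure-Python sketch with made-up logits (the helper name `softmax_with_temperature` is illustrative):

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    t = max(temperature, 1e-6)  # same guard against zero as GPT.generate
    scaled = [x / t for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate bytes
for temp in (0.5, 0.9, 1.0, 2.0):
    probs = softmax_with_temperature(logits, temp)
    print(temp, [round(p, 3) for p in probs])
```

Temperature only reshapes the distribution; top-k and top-p then decide which tokens remain eligible at all.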

Validation loss

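Compute the average validation loss over 50 random batches as a quick quality check. Note that the corpus repeats the same three lines 50 times, so the validation split (the last 10%) contains text the model has already seen during training; a low value here reflects memorization rather than generalization.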

model.eval()
losses = []
with torch.no_grad():
    for _ in range(50):
        xb, yb = ds.get_batch("val", batch_size=32, device=device)
        _, loss = model(xb, yb)
        losses.append(loss.item())
print(f"val loss (approx): {sum(losses)/len(losses):.4f}")
val loss (approx): 0.3233
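`F.cross_entropy` uses the natural logarithm, so this loss is in nats per byte. Two common, equivalent views of the same number are bits per byte (loss / ln 2) and perplexity (e^loss, roughly the effective number of equally likely next-byte choices). A small conversion sketch using the value printed above:

```python
import math

val_loss = 0.3233  # nats per byte, from the cell above
bits_per_byte = val_loss / math.log(2)
perplexity = math.exp(val_loss)
print(f"{bits_per_byte:.3f} bits/byte, perplexity {perplexity:.3f}")
# 0.466 bits/byte, perplexity 1.382
```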

Parameter count and model size

Print the total number of trainable parameters and the approximate model size, assuming 4 bytes per parameter (float32).

def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


num_params = count_parameters(model)
size_mb = num_params * 4 / (1024**2)  # float32
print(f"parameters: {num_params:,} (~{size_mb:.2f} MB fp32)")
parameters: 469,504 (~1.79 MB fp32)
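The printed count can be reproduced by hand from the hyperparameters, which is a handy check when changing the configuration. A sketch of that arithmetic for the GPT class defined above (`gpt_param_count` is an illustrative helper; it assumes the same layout: bias-free qkv, proj, and head layers, feed-forward layers with biases, and LayerNorms with weight and bias):

```python
def gpt_param_count(vocab_size, block_size, n_layer, n_embd, mult=4):
    """Count trainable parameters of the notebook's GPT from its hyperparameters."""
    embeddings = vocab_size * n_embd + block_size * n_embd  # tok_emb + pos_emb
    layer_norm = 2 * n_embd  # weight + bias
    attn = n_embd * 3 * n_embd + n_embd * n_embd  # qkv + proj, both bias=False
    ffn = (n_embd * mult * n_embd + mult * n_embd) + (mult * n_embd * n_embd + n_embd)
    block = 2 * layer_norm + attn + ffn  # ln1, ln2, attention, feed-forward
    head = n_embd * vocab_size  # bias=False
    return embeddings + n_layer * block + layer_norm + head  # layer_norm here is ln_f


print(gpt_param_count(vocab_size=256, block_size=64, n_layer=2, n_embd=128))  # 469504
```

Each block contributes 197,760 parameters, so the two blocks account for about 84% of the total. `n_head` does not appear: splitting into heads changes how the projections are used, not their sizes.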