We’ll build batches from an in-memory corpus string (no external files).
- Inputs x: bytes [t, ..., t+block_size-1]
- Labels y: same sequence shifted left by 1 (predict next token)
- Random start indices for each sample
We also show a small preview and validate the shift property.
class ByteDataset:
    """Builds (x, y) pairs from a single long tensor of token ids.

    Splits into train/val by fraction and samples random contiguous blocks.
    """

    def __init__(self, data: torch.Tensor, block_size: int = 256, split: float = 0.9):
        n = int(len(data) * split)
        self.train = data[:n].clone()
        self.val = data[n:].clone()
        self.block_size = block_size

    def get_batch(self, which: str, batch_size: int, device: torch.device):
        buf = self.train if which == "train" else self.val
        assert len(buf) > self.block_size + 1, "corpus too small for given block_size"
        ix = torch.randint(0, len(buf) - self.block_size - 1, (batch_size,))
        x = torch.stack([buf[i : i + self.block_size] for i in ix])
        y = torch.stack([buf[i + 1 : i + 1 + self.block_size] for i in ix])
        return x.to(device), y.to(device)
# In-memory tiny corpus (repeat to ensure enough length)
corpus = (
    "Hello there! This is a tiny demo corpus for training a byte-level GPT.\n"
    "It is intentionally small so the demo runs fast on CPU.\n"
    "Have fun experimenting with temperature, top-k, and top-p!\n"
) * 50

# Tokenize corpus to bytes
corpus_ids = tok.encode(corpus)
print("corpus length:", len(corpus_ids))
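The preview and shift check mentioned above are not reproduced in this excerpt. A minimal sketch of how they might look, assuming `corpus_ids` is a 1-D tensor of byte ids and using an illustrative `block_size` of 64 (the notebook's actual value may differ):

dataset = ByteDataset(corpus_ids, block_size=64, split=0.9)
xb, yb = dataset.get_batch("train", batch_size=4, device=torch.device("cpu"))
print(xb.shape, yb.shape)                  # (4, 64) each

# Shift property: y is x shifted left by one position within the corpus,
# so x[:, 1:] must equal y[:, :-1] for every sampled block.
assert torch.equal(xb[:, 1:], yb[:, :-1])
print(tok.decode(xb[0])[:80])              # peek at the raw text of one block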
Architecture overview:
- Token + learned position embeddings
- N pre-norm Transformer blocks: self-attention + MLP with residuals
- Final layer norm + linear head
Defining the CausalSelfAttention and FeedForward modules, and the Block that combines them.
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d_head = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):  # x: (B, T, C)
        B, T, C = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_head, self.d_head)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, n_embd: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, mult * n_embd),
            nn.GELU(),
            nn.Linear(mult * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FeedForward(n_embd, mult=4, dropout=dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
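The full model class is not shown in this excerpt. The sketch below illustrates how the overview above (token + learned position embeddings, a stack of Blocks, final layer norm + linear head) might be assembled; the class name `TinyGPT` and all hyperparameter values are illustrative, not the notebook's actual configuration.

class TinyGPT(nn.Module):
    """Illustrative assembly of the components above; not the notebook's exact class."""

    def __init__(self, vocab_size: int = 256, block_size: int = 64,
                 n_layer: int = 4, n_head: int = 4, n_embd: int = 128, dropout: float = 0.0):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)         # token embeddings
        self.pos_emb = nn.Embedding(block_size, n_embd)         # learned position embeddings
        self.blocks = nn.ModuleList([Block(n_embd, n_head, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)                        # final layer norm
        self.head = nn.Linear(n_embd, vocab_size, bias=False)   # linear head

    def forward(self, idx: torch.Tensor):                       # idx: (B, T) token ids
        B, T = idx.shape
        assert T <= self.block_size, "sequence longer than block_size"
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)                # (B, T, C)
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_f(x))                           # logits: (B, T, vocab_size)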
Sampling utility:
Utility function for sampling from a model’s output, including support for top-k and top-p filtering.
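The utility itself is not reproduced in this excerpt; below is a minimal sketch of one common way to implement it, assuming logits of shape (B, vocab_size). The function name `sample_next_token` and its signature are illustrative, not the notebook's API.

from typing import Optional

import torch
import torch.nn.functional as F


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_k: Optional[int] = None, top_p: Optional[float] = None) -> torch.Tensor:
    """Sample token ids from (B, vocab) logits with optional top-k / top-p filtering."""
    logits = logits / max(temperature, 1e-8)
    if top_k is not None:
        # Keep only the top_k largest logits; mask the rest to -inf.
        kth = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        # Nucleus filtering: drop tokens outside the smallest set whose
        # cumulative probability exceeds top_p (always keep the top token).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = F.softmax(sorted_logits, dim=-1)
        cum = torch.cumsum(probs, dim=-1)
        remove = cum - probs > top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)   # (B, 1) sampled ids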
step 50 | train 3.2102 | val 3.1415
sample:
vunPtetpTlkav [i ytmux!ve po
aP g i!sh g s rfa ayc! ng--erle
In,oipe ePm(ot g.lxrvo -Hpe ynxiinoP d xo t s vuxpe.TPr
step 100 | train 2.0894 | val 2.0733
sample:
heHe io.
therp-louo edefmisubg ent wnt ifa y tCee s, a s.
H or
Itin
HPUoplU. t -ll Gin
e ere blfeoiir fs is t te!p
step 150 | train 1.3808 | val 1.3731
sample:
Heby elhelo inde sl so e CPU.nls nthemat ders timpemounmshensaveU fant funt eo dhon orao fokr ire trpenisyt, fhG.
Hiis w
step 200 | train 0.9038 | val 0.8888
sample:
indemo tonPU.oplns :mavel fus h ton for eraihwe byntuntemxpereryty GPThint-le, d des t usorap, rasrop-ls i t tornwioe
step 250 | train 0.5369 | val 0.5361
sample:
Uve fexpinthy temeruns imendeomh in tis deny furo s te, mexpemors funts trvenin]ratintsly teon alloally o colly .
Ha$(h
step 300 | train 0.3238 | val 0.3228
sample:
avel GPT.
He it inteundemois o mksl nallnthe CPU.
Hal inse io fun-lPT.
Have y s e funtfunste emaxpenstrentinontionU.
It
Manual sampling cell
Provide a prompt and generate a continuation with adjustable temperature, top_k, and top_p.
model.eval()

# Try different prompts and decoding controls
prompt = "Once upon a time "
start_ids = tok.encode(prompt).unsqueeze(0).to(device)
with torch.no_grad():
    out = model.generate(
        start_ids,
        max_new_tokens=200,
        temperature=0.9,
        top_k=50,
        top_p=None,
    )
print(tok.decode(out[0].cpu())[-300:])
Once upon a time funers foastrng forat ainwith tempe-p-lls teutorenand ths ere, re, and torp-k, td and top!o top!
Ithehe! t rpp
He!thire Thereis tinihie is a a tiny te-llos y demo g tra forentunT.
It ininina wio the
2.6 Validation loss
Compute an average validation loss over a few batches for a quick quality check.
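The evaluation cell is not shown in this excerpt. A minimal sketch, assuming the `ByteDataset` instance from the preview sketch above is named `dataset` and that `model(x)` returns logits of shape (B, T, vocab_size); the helper name `estimate_val_loss` and the batch counts are illustrative.

@torch.no_grad()
def estimate_val_loss(model, dataset, device, batches: int = 10, batch_size: int = 16) -> float:
    """Average cross-entropy over a few random validation batches."""
    model.eval()
    losses = []
    for _ in range(batches):
        x, y = dataset.get_batch("val", batch_size, device)
        logits = model(x)                                   # assumed shape: (B, T, vocab_size)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        losses.append(loss.item())
    return sum(losses) / len(losses)

print("val loss:", estimate_val_loss(model, dataset, device))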