Skip to content

Instantly share code, notes, and snippets.

@skannan-maf
Forked from karpathy/microgpt.py
Last active February 16, 2026 18:07
Show Gist options
  • Select an option

  • Save skannan-maf/ae25bddd54ad0b9d137030c3c775b3a3 to your computer and use it in GitHub Desktop.

Select an option

Save skannan-maf/ae25bddd54ad0b9d137030c3c775b3a3 to your computer and use it in GitHub Desktop.
microgpt
"""
The most atomic way to train and inference a GPT in pure, dependency-free Python.
This file is the complete algorithm.
Everything else is just efficiency.
--
@karpathy
1. Commented out a seemingly redundant "rmsnorm"
2. Increased the width (block_size) from 16 to 32 to support long names
3. Modified the output samples so that first 50% of output starts with "sar"
4. Curated an Indian dataset from other repositories (file with full names)
Added the following using Codex in about 10 mins:
1. Checkpointing & saving a model
2. Resume-from-checkpoint / saved-model
3. Predict-only mode
4. Random seed control
5. Configurable training steps
6. Configurable sampling temperature for output
7. Configurable output count
@skannan-maf
"""
import os # os.path.exists
import math # math.log, math.exp
import random # random.seed, random.choices, random.gauss, random.shuffle
import argparse
import pickle

random.seed(42) # Let there be order among chaos

# Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
INPUT_FILENAME = 'Indian-names.txt'
INPUT_URL = 'https://raw.githubusercontent.com/skannan-maf/Indian-names/refs/heads/main/Indian-names.txt'
DEFAULT_CHECKPOINT_DIR = 'checkpoints'
DEFAULT_PARAMS_FILENAME = os.path.join(DEFAULT_CHECKPOINT_DIR, 'microgpt_params.pkl')
#INPUT_FILENAME = 'Indian-Firstnames.txt'
#INPUT_URL= 'https://raw.githubusercontent.com/MASTREX/List-of-Indian-Names/refs/heads/master/2.%20First.txt'
#org:input.txt
#orgurl:'https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt'
if not os.path.exists(INPUT_FILENAME):
    import urllib.request
    urllib.request.urlretrieve(INPUT_URL, INPUT_FILENAME)  # fix: dropped unused `names_url` binding
# fix: use a context manager so the file handle is closed deterministically
with open(INPUT_FILENAME) as f:
    docs = [l.strip() for l in f.read().strip().split('\n') if l.strip()] # list[str] of documents
random.shuffle(docs)
print(f"num docs: {len(docs)}")
# Command-line interface: every training/inference knob is configured via flags.
parser = argparse.ArgumentParser()
_cli_options = [
    (('--reload',), dict(action='store_true', help='Reload saved model parameters before training')),
    (('--predict',), dict(action='store_true', help='Skip training and run inference using saved model parameters')),
    (('--model', '--params-file'), dict(dest='model', default=DEFAULT_PARAMS_FILENAME, help='Model checkpoint file used to load/save parameters')),
    (('--num-steps',), dict(type=int, default=10, help='Number of training steps')),
    (('--num-samples',), dict(type=int, default=100, help='Number of output samples to generate during inference')),
    (('--temperature',), dict(type=float, default=0.7, help='Sampling temperature for inference')),
    (('--seed',), dict(type=int, default=None, help='Random seed used only for inference sampling')),
    (('--prefix',), dict(type=str, default='sar', help='Force this prefix on the first half of generated samples')),
]
for _flags, _opts in _cli_options:
    parser.add_argument(*_flags, **_opts)
args = parser.parse_args()

# Reject nonsensical numeric settings up front, before any expensive work.
if args.num_steps < 0:
    raise ValueError('--num-steps must be >= 0')
if args.num_samples < 0:
    raise ValueError('--num-samples must be >= 0')
if args.temperature <= 0:
    raise ValueError('--temperature must be > 0')

# A bare filename (no directory component) is placed inside the default checkpoint dir.
if os.path.dirname(args.model) == '':
    args.model = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model)
# Let there be a Tokenizer: every distinct character in the dataset becomes a token id 0..n-1.
uchars = sorted({ch for doc in docs for ch in doc})
BOS = len(uchars)  # the special Beginning-of-Sequence token takes the last id
vocab_size = len(uchars) + 1  # all dataset characters plus BOS
print(f"vocab size: {vocab_size}")
# Let there be Autograd, to recursively apply the chain rule through a computation graph
class Value:
    """A scalar node in a dynamically-built computation graph (reverse-mode autodiff).

    Each arithmetic operation returns a new Value that remembers its operand
    nodes and the local derivative of the result w.r.t. each operand, so that
    backward() can apply the chain rule over the whole graph.
    """
    __slots__ = ('data', 'grad', '_prev', '_local')  # no per-instance __dict__: saves memory

    def __init__(self, data, children=(), local_grads=()):
        self.data = data           # scalar produced by the forward pass
        self.grad = 0              # dLoss/dThisNode, accumulated by backward()
        self._prev = children      # operand nodes this value was computed from
        self._local = local_grads  # d(self)/d(child) for each operand

    def __add__(self, other):
        if not isinstance(other, Value):
            other = Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        if not isinstance(other, Value):
            other = Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, exponent):
        # `exponent` is always a plain number here, never a Value.
        return Value(self.data ** exponent, (self,), (exponent * self.data ** (exponent - 1),))

    def log(self):
        return Value(math.log(self.data), (self,), (1 / self.data,))

    def exp(self):
        e = math.exp(self.data)  # d/dx exp(x) = exp(x): reuse the forward result
        return Value(e, (self,), (e,))

    def relu(self):
        return Value(max(0, self.data), (self,), (float(self.data > 0),))

    # Derived operators, all routed through __add__/__mul__/__pow__ above.
    def __neg__(self): return self * -1
    def __radd__(self, other): return self + other
    def __sub__(self, other): return self + (-other)
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other
    def __truediv__(self, other): return self * other ** -1
    def __rtruediv__(self, other): return other * self ** -1

    def backward(self):
        """Backpropagate from this node: set self.grad = 1 and accumulate
        gradients into every node it was computed from, in reverse
        topological order (iterative DFS, so deep graphs cannot blow the
        recursion limit)."""
        order = []
        seen = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                order.append(node)  # all operands already emitted
            elif node not in seen:
                seen.add(node)
                stack.append((node, True))
                for parent in node._prev:
                    stack.append((parent, False))
        self.grad = 1
        for node in reversed(order):
            for parent, local in zip(node._prev, node._local):
                parent.grad += local * node.grad
# Initialize the parameters, to store the knowledge of the model.
n_embd = 16 # embedding dimension
n_head = 4 # number of attention heads
n_layer = 1 # number of layers
block_size = 32 # org:16 maximum sequence length
head_dim = n_embd // n_head # dimension of each head

def matrix(nout, nin, std=0.08):
    # Gaussian-initialized weight matrix of autograd Values, shape (nout, nin).
    return [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]

state_dict = {
    'wte': matrix(vocab_size, n_embd),      # token embedding table
    'wpe': matrix(block_size, n_embd),      # positional embedding table
    'lm_head': matrix(vocab_size, n_embd),  # final projection to vocabulary logits
}
for layer_idx in range(n_layer):
    key_prefix = f'layer{layer_idx}.'
    state_dict[key_prefix + 'attn_wq'] = matrix(n_embd, n_embd)
    state_dict[key_prefix + 'attn_wk'] = matrix(n_embd, n_embd)
    state_dict[key_prefix + 'attn_wv'] = matrix(n_embd, n_embd)
    state_dict[key_prefix + 'attn_wo'] = matrix(n_embd, n_embd)
    state_dict[key_prefix + 'mlp_fc1'] = matrix(4 * n_embd, n_embd)
    state_dict[key_prefix + 'mlp_fc2'] = matrix(n_embd, 4 * n_embd)
def export_state(sd):
    """Strip the autograd wrappers: map each parameter name to a nested list of raw floats."""
    exported = {}
    for name, mat in sd.items():
        exported[name] = [[p.data for p in row] for row in mat]
    return exported
def save_state_to_file(sd, path):
    """Pickle the raw parameter values of `sd` to `path`, creating parent dirs as needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as fh:
        pickle.dump(export_state(sd), fh)
def checkpoint_path_for_step(model_path, step_num):
    """Derive a per-step checkpoint filename, e.g. ckpt/m.pkl -> ckpt/m_step100.pkl."""
    directory, filename = os.path.split(model_path)
    stem, ext = os.path.splitext(filename)
    step_name = f"{stem}_step{step_num}{ext}"
    return os.path.join(directory, step_name)
def load_state(sd, raw_state):
    """Copy raw float values from `raw_state` into the parameter objects of `sd`, in place.

    Raises KeyError when a parameter name is absent from the pickle and
    ValueError when any row or column count disagrees with the model shape.
    """
    for name, mat in sd.items():
        if name not in raw_state:
            raise KeyError(f"missing key in pickle: {name}")
        raw_mat = raw_state[name]
        if len(raw_mat) != len(mat):
            raise ValueError(f"shape mismatch for {name}: expected {len(mat)} rows, got {len(raw_mat)}")
        for row_idx, (row, raw_row) in enumerate(zip(mat, raw_mat)):
            if len(raw_row) != len(row):
                raise ValueError(
                    f"shape mismatch for {name}[{row_idx}]: expected {len(row)} cols, got {len(raw_row)}"
                )
            for p, raw_val in zip(row, raw_row):
                p.data = float(raw_val)
# Optionally restore parameters from a previous run before training/inference.
want_saved_weights = args.predict or args.reload
if want_saved_weights and os.path.exists(args.model):
    with open(args.model, 'rb') as fh:
        saved_state = pickle.load(fh)
    load_state(state_dict, saved_state)
    print(f"reloaded parameters from: {args.model}")
elif args.predict:
    # Inference is impossible without weights: fail hard.
    raise FileNotFoundError(f"--predict requires a saved model file, but not found: {args.model}")
elif args.reload:
    # --reload is best-effort: warn and fall through to fresh training.
    print(f"--reload was set but file not found: {args.model}")

# Flatten every weight matrix into one list[Value] for the optimizer to walk.
params = [p for mat in state_dict.values() for row in mat for p in row]
print(f"num params: {len(params)}")
# Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next.
# Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU
def linear(x, w):
    """Matrix-vector product: returns one dot product of x per row of w."""
    out = []
    for row in w:
        out.append(sum(weight * activation for weight, activation in zip(row, x)))
    return out
def softmax(logits):
    """Numerically stable softmax over a list of autograd Value nodes."""
    peak = max(node.data for node in logits)  # subtracting the max keeps exp() from overflowing
    exps = [(node - peak).exp() for node in logits]
    denom = sum(exps)
    return [e / denom for e in exps]
def rmsnorm(x):
    """Scale x to unit root-mean-square; the epsilon keeps an all-zero vector finite."""
    mean_sq = sum(component * component for component in x) / len(x)
    inv_rms = (mean_sq + 1e-5) ** -0.5
    return [component * inv_rms for component in x]
def gpt(token_id, pos_id, keys, values):
    """Forward one token through the transformer and return next-token logits.

    token_id, pos_id: ints indexing the current token and its sequence position.
    keys, values: per-layer KV cache (list of per-position vectors); this call
    APPENDS the current position's k/v, so the caller must pass the same lists
    for every position of one sequence (causality comes from the cache only
    containing past and current positions).
    """
    tok_emb = state_dict['wte'][token_id] # token embedding
    pos_emb = state_dict['wpe'][pos_id] # position embedding
    x = [t + p for t, p in zip(tok_emb, pos_emb)] # joint token and position embedding
    #x = rmsnorm(x) #Commented by Sarnath: We do this again in FOR loop below. So..
    for li in range(n_layer):
        # 1) Multi-head attention block (pre-norm residual)
        x_residual = x
        x = rmsnorm(x)
        q = linear(x, state_dict[f'layer{li}.attn_wq'])
        k = linear(x, state_dict[f'layer{li}.attn_wk'])
        v = linear(x, state_dict[f'layer{li}.attn_wv'])
        # grow the cache with this position's key/value (side effect on caller's lists)
        keys[li].append(k)
        values[li].append(v)
        x_attn = []
        for h in range(n_head):
            hs = h * head_dim  # start column of this head's slice of q/k/v
            q_h = q[hs:hs+head_dim]
            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
            # scaled dot-product attention over every cached position
            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
            attn_weights = softmax(attn_logits)
            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
            x_attn.extend(head_out)  # concatenate head outputs back to n_embd
        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
        x = [a + b for a, b in zip(x, x_residual)]  # residual connection
        # 2) MLP block (pre-norm residual, 4x expansion, ReLU)
        x_residual = x
        x = rmsnorm(x)
        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
        x = [xi.relu() for xi in x]
        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
        x = [a + b for a, b in zip(x, x_residual)]  # residual connection
    logits = linear(x, state_dict['lm_head'])
    return logits
# Let there be Adam, the blessed optimizer and its buffers
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
m = [0.0 for _ in params]  # first-moment (mean of gradients) buffer, one slot per parameter
v = [0.0 for _ in params]  # second-moment (mean of squared gradients) buffer

# Repeat in sequence
num_steps = args.num_steps # number of training steps #org: 1000
user_choice = 'y'  # consulted after an interrupted run: proceed to inference or bail out
if not args.predict:
    # Rolling checkpoints: save every `checkpoint_every` steps, keeping only the newest.
    checkpoint_every = 1000
    last_checkpoint_path = None
    try:
        for step in range(num_steps):
            # Take single document, tokenize it, surround it with BOS special token on both sides
            doc = docs[step % len(docs)]
            tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
            n = min(block_size, len(tokens) - 1)  # number of (input, target) pairs we can form
            # Forward the token sequence through the model, building up the computation graph all the way to the loss.
            keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
            losses = []
            for pos_id in range(n):
                token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
                logits = gpt(token_id, pos_id, keys, values)
                probs = softmax(logits)
                loss_t = -probs[target_id].log()  # cross-entropy at this position
                losses.append(loss_t)
            loss = (1 / n) * sum(losses) # final average loss over the document sequence. May yours be low.
            # Backward the loss, calculating the gradients with respect to all model parameters.
            loss.backward()
            # Adam optimizer update: update the model parameters based on the corresponding gradients.
            lr_t = learning_rate * (1 - step / num_steps) # linear learning rate decay
            for i, p in enumerate(params):
                m[i] = beta1 * m[i] + (1 - beta1) * p.grad
                v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
                m_hat = m[i] / (1 - beta1 ** (step + 1))  # bias-corrected first moment
                v_hat = v[i] / (1 - beta2 ** (step + 1))  # bias-corrected second moment
                p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
                p.grad = 0  # zero the gradient for the next step
            print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")
            if (step + 1) % checkpoint_every == 0:
                ckpt_path = checkpoint_path_for_step(args.model, step + 1)
                save_state_to_file(state_dict, ckpt_path)
                print(f"saved checkpoint to: {ckpt_path}")
                # Keep disk usage bounded: drop the previous rolling checkpoint.
                if last_checkpoint_path and last_checkpoint_path != ckpt_path and os.path.exists(last_checkpoint_path):
                    os.remove(last_checkpoint_path)
                    print(f"deleted old checkpoint: {last_checkpoint_path}")
                last_checkpoint_path = ckpt_path
    except KeyboardInterrupt:
        print('\nTraining interrupted by user (Ctrl+C).')
        user_choice = input('Training exited now....Do you want to do inference now (y/n)')
    except Exception as e:
        # Best-effort recovery: show the traceback, then let the user decide whether
        # the partially trained weights are still worth sampling from.
        import traceback
        traceback.print_exc()
        user_choice = input('Training exited now....Do you want to do inference now (y/n)')
    if (user_choice != 'y') and (user_choice != 'Y'):
        import sys
        sys.exit(-1)
    save_state_to_file(state_dict, args.model)
    print(f"saved parameters to: {args.model}")
# Inference: may the model babble back to us
if args.seed is not None:
    random.seed(args.seed)  # reseed so sampling is reproducible regardless of how training used the RNG
temperature = args.temperature # #org:0.5 in (0, 1], control the "creativity" of generated text, low to high

# Validate the forced prefix up front so a bad --prefix fails before any sampling.
for ch in args.prefix:
    if ch not in uchars:
        raise ValueError(f"prefix character {ch!r} not present in dataset vocabulary")
static_prefix_tokens = [uchars.index(ch) for ch in args.prefix]

print("\n--- inference (new, hallucinated names) ---")
for sample_idx in range(args.num_samples):
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    token_id = BOS
    sample = []
    for pos_id in range(block_size):
        logits = gpt(token_id, pos_id, keys, values)
        probs = softmax([l / temperature for l in logits])
        token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
        # Force the configured prefix onto the first half of the samples.
        # fix: the override now runs BEFORE the BOS check; previously a sampled BOS
        # inside the prefix window terminated the sample early, so some of the first
        # 50% of samples did not actually start with the prefix. Also dropped the
        # always-true `pos_id >= 0` test.
        if (sample_idx < args.num_samples // 2) and (pos_id < len(static_prefix_tokens)):
            token_id = static_prefix_tokens[pos_id]
        if token_id == BOS:
            break  # BOS doubles as the end-of-sequence marker
        sample.append(uchars[token_id])
    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment