Skip to content

Instantly share code, notes, and snippets.

@maujim
Created July 22, 2025 22:29
Show Gist options
  • Save maujim/6035253aed015a3592b46a69dbd947a8 to your computer and use it in GitHub Desktop.
# single file implementation of forward-forward algorithm
# https://arxiv.org/pdf/2212.13345 (hinton, 2022)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
class ForwardForwardLayer(nn.Module):
    """A single layer trained with the forward-forward objective.

    Each layer is optimized in isolation: positive samples should produce
    "goodness" (mean squared activation) above ``threshold``, negative
    samples below it.
    """

    def __init__(self, in_dim, out_dim, threshold=2.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.relu = nn.ReLU()
        self.threshold = threshold

    def forward(self, x):
        # Length-normalize each sample so only the *direction* of the
        # incoming activity vector is passed on; the epsilon guards
        # against division by zero for all-zero inputs.
        scale = x.norm(2, dim=1, keepdim=True) + 1e-4
        return self.relu(self.linear(x / scale))

    def goodness(self, x):
        # Mean squared activation per sample (shape: [batch]).
        return x.pow(2).mean(dim=1)

    def forward_forward_loss(self, x_pos, x_neg):
        """Softplus margin loss pushing positive goodness above the
        threshold and negative goodness below it."""
        g_pos = self.goodness(self(x_pos))
        g_neg = self.goodness(self(x_neg))
        margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
        return F.softplus(margins).mean()
class ForwardForwardNet(nn.Module):
    """A stack of ForwardForwardLayers trained greedily, one at a time.

    ``dims`` lists layer widths, e.g. [784, 500, 500] builds two layers.
    """

    def __init__(self, dims, threshold=2.0):
        super().__init__()
        widths = list(zip(dims[:-1], dims[1:]))
        self.layers = nn.ModuleList(
            ForwardForwardLayer(d_in, d_out, threshold=threshold)
            for d_in, d_out in widths
        )

    def forward_through_layers(self, x, stop_at=None):
        """Propagate x through the stack; if ``stop_at`` is given, stop
        *after* applying the layer with that index."""
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if stop_at is not None and idx == stop_at:
                break
        return x

    def predict(self, x):
        """Classify by trying every label overlay (0-9) and picking the
        one whose total goodness, summed over all layers, is highest."""
        with torch.no_grad():
            batch = x.size(0)
            per_label_scores = []
            for label in range(10):
                labels = torch.full((batch,), label, dtype=torch.long,
                                    device=x.device)
                h = overlay_y_on_x(x.clone(), labels)
                total = torch.zeros(batch, device=x.device)
                for layer in self.layers:
                    h = layer(h)
                    total = total + layer.goodness(h)
                per_label_scores.append(total.unsqueeze(1))
            return torch.cat(per_label_scores, dim=1).argmax(dim=1)
def overlay_y_on_x(x, y):
    """Embed the label into the first 10 features of each flattened image.

    Mutates ``x`` in place (callers pass a clone): zeroes columns 0-9,
    then sets column ``y[i]`` of row ``i`` to the batch max. Note the max
    is taken *after* zeroing, so it reflects the remaining pixel values.
    Returns the modified tensor for convenience.
    """
    batch = x.size(0)
    x[:, :10] = 0.0
    peak = x.max()
    x[torch.arange(batch), y] = peak
    return x
def get_mnist_loaders(train_bs=128, test_bs=512):
    """Build MNIST train/test DataLoaders yielding flattened 784-dim images.

    Downloads the dataset into ./data on first use.
    """
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),   # standard MNIST mean/std
        Lambda(lambda img: img.view(-1)),  # flatten 1x28x28 -> 784
    ])
    train_ds = MNIST('./data', train=True, download=True, transform=transform)
    test_ds = MNIST('./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=train_bs, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=test_bs)
    return train_loader, test_loader
def train_ff(model, train_loader, num_epochs=5, lr=0.03, device='cuda'):
    """Greedy layer-wise forward-forward training.

    Each layer gets its own optimizer and is trained for ``num_epochs``
    on positive samples (correct label overlaid on the image) vs negative
    samples (randomly permuted labels).

    BUG FIX: the original passed the raw 784-dim overlaid input to every
    layer, which crashes on layer 1 and above (they expect the previous
    layer's width). Each layer must instead be trained on the output of
    the already-trained layers below it, detached so gradients stay local
    to the layer currently being trained.
    """
    model.to(device)
    for i, layer in enumerate(model.layers):
        print(f'training layer {i}...')
        optimizer = torch.optim.Adam(layer.parameters(), lr=lr)
        for epoch in range(num_epochs):
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                # positive: true label overlaid; negative: shuffled labels
                # (NOTE: a shuffled label can coincide with the true one
                # for some samples — a known soft spot of this scheme)
                x_pos = overlay_y_on_x(x.clone(), y)
                y_neg = y[torch.randperm(y.size(0))]
                x_neg = overlay_y_on_x(x.clone(), y_neg)
                # propagate through the frozen, already-trained layers
                with torch.no_grad():
                    for prev in model.layers[:i]:
                        x_pos = prev(x_pos)
                        x_neg = prev(x_neg)
                loss = layer.forward_forward_loss(x_pos, x_neg)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
def evaluate(model, loader, device='cuda'):
    """Return the classification error rate (1 - accuracy) over ``loader``."""
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model.predict(inputs)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return 1.0 - n_correct / n_seen
# kick it off
if __name__ == '__main__':
    torch.manual_seed(42)
    # Fall back to CPU when no GPU is present (original hard-coded the
    # 'cuda' defaults and crashed on CPU-only machines).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_loader, test_loader = get_mnist_loaders()
    # 784 input pixels -> two hidden layers of 500 units each
    ff_model = ForwardForwardNet([784, 500, 500])
    train_ff(ff_model, train_loader, device=device)
    train_err = evaluate(ff_model, train_loader, device=device)
    test_err = evaluate(ff_model, test_loader, device=device)
    print(f"train error: {train_err:.4f}")
    print(f"test error: {test_err:.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment