# single file implementation of the forward-forward algorithm
# https://arxiv.org/pdf/2212.13345 (hinton, 2022)
# each layer is trained locally: it learns high "goodness" on positive
# (correctly labelled) data and low goodness on negative (wrongly labelled)
# data, so no end-to-end backprop is needed
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda


class ForwardForwardLayer(nn.Module):
    def __init__(self, in_dim, out_dim, threshold=2.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.relu = nn.ReLU()
        self.threshold = threshold

    def forward(self, x):
        # length-normalize the input so a layer cannot simply reuse the
        # magnitude of the previous layer's activity as goodness
        x_norm = x / (x.norm(2, dim=1, keepdim=True) + 1e-4)
        return self.relu(self.linear(x_norm))

    def goodness(self, x):
        # goodness of an activity vector = its mean squared activation
        return (x ** 2).mean(dim=1)

    def forward_forward_loss(self, x_pos, x_neg):
        # push goodness above the threshold on positive data and below it
        # on negative data; softplus gives a smooth margin on both sides
        g_pos = self.goodness(self.forward(x_pos))
        g_neg = self.goodness(self.forward(x_neg))
        return F.softplus(torch.cat([-g_pos + self.threshold,
                                     g_neg - self.threshold])).mean()
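

# a minimal sketch (not called anywhere in this script) of what local
# training means here: a single layer can take an optimizer step from its
# own forward-forward loss alone, with no gradients flowing to or from any
# other layer; shapes and hyperparameters below are illustrative, not the
# values used by the training loop further down
def _single_layer_step_demo():
    layer = ForwardForwardLayer(784, 500)
    optimizer = torch.optim.Adam(layer.parameters(), lr=0.03)
    x_pos = torch.randn(8, 784)  # stand-in for positively labelled data
    x_neg = torch.randn(8, 784)  # stand-in for negatively labelled data
    loss = layer.forward_forward_loss(x_pos, x_neg)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()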


class ForwardForwardNet(nn.Module):
    def __init__(self, dims, threshold=2.0):
        super().__init__()
        self.layers = nn.ModuleList([
            ForwardForwardLayer(dims[i], dims[i + 1], threshold=threshold)
            for i in range(len(dims) - 1)
        ])

    def forward_through_layers(self, x, stop_at=None):
        # helper: run x through the stack, optionally stopping after the
        # layer with index stop_at (inclusive)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if stop_at is not None and i == stop_at:
                break
        return x

    def predict(self, x):
        # score every candidate label by overlaying it on the input and
        # summing the goodness of all layers; predict the best-scoring label
        with torch.no_grad():
            goodness_per_label = []
            for label in range(10):
                x_overlay = overlay_y_on_x(
                    x.clone(),
                    torch.full((x.size(0),), label, dtype=torch.long,
                               device=x.device))
                g_total = torch.zeros(x.size(0), device=x.device)
                for layer in self.layers:
                    x_overlay = layer(x_overlay)
                    g_total += layer.goodness(x_overlay)
                goodness_per_label.append(g_total.unsqueeze(1))
            return torch.cat(goodness_per_label, dim=1).argmax(dim=1)


def overlay_y_on_x(x, y):
    # embed the label in the image itself: zero out the first ten pixels,
    # then set the pixel at index y to the batch maximum (modifies x in
    # place, which is why callers pass a clone)
    x[:, :10] = 0.0
    x[torch.arange(x.size(0)), y] = x.max()
    return x
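

# a small illustration (not used by training or evaluation) of the overlay
# trick: for a batch of two flattened "images" with labels 3 and 7, the
# first ten components become a one-hot code for the label, scaled to the
# batch maximum; the sizes here are toy values, not MNIST dimensions
def _overlay_demo():
    x = torch.arange(24, dtype=torch.float32).reshape(2, 12)
    out = overlay_y_on_x(x.clone(), torch.tensor([3, 7]))
    # row 0: entries 0-9 are zero except entry 3, which is 23.0 (the max);
    # row 1: entries 0-9 are zero except entry 7
    print(out)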


def get_mnist_loaders(train_bs=128, test_bs=512):
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda x: x.view(-1))  # flatten 28x28 images to 784-vectors
    ])
    train_set = MNIST('./data', train=True, download=True, transform=transform)
    test_set = MNIST('./data', train=False, download=True, transform=transform)
    return (DataLoader(train_set, batch_size=train_bs, shuffle=True),
            DataLoader(test_set, batch_size=test_bs))


def train_ff(model, train_loader, num_epochs=5, lr=0.03,
             device='cuda' if torch.cuda.is_available() else 'cpu'):
    # layer-wise training: each layer gets its own optimizer and only ever
    # receives the detached output of the layers before it
    model.to(device)
    for i, layer in enumerate(model.layers):
        print(f'training layer {i}...')
        optimizer = torch.optim.Adam(layer.parameters(), lr=lr)
        for epoch in range(num_epochs):
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                x_pos = overlay_y_on_x(x.clone(), y)
                # negative data: the same images overlaid with shuffled
                # (mostly wrong) labels
                y_neg = y[torch.randperm(y.size(0))]
                x_neg = overlay_y_on_x(x.clone(), y_neg)
                # propagate through the already-trained layers without
                # gradients so only the current layer is updated (without
                # this, layer i > 0 would be fed raw 784-dim inputs and
                # crash on the shape mismatch)
                with torch.no_grad():
                    for prev in model.layers[:i]:
                        x_pos = prev(x_pos)
                        x_neg = prev(x_neg)
                loss = layer.forward_forward_loss(x_pos, x_neg)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()


def evaluate(model, loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model.predict(x)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 1.0 - correct / total  # error rate in [0, 1]


# kick it off
if __name__ == '__main__':
    torch.manual_seed(42)
    train_loader, test_loader = get_mnist_loaders()
    ff_model = ForwardForwardNet([784, 500, 500])
    train_ff(ff_model, train_loader)
    train_err = evaluate(ff_model, train_loader)
    test_err = evaluate(ff_model, test_loader)
    print(f"train error: {train_err:.4f}")
    print(f"test error: {test_err:.4f}")