Created November 23, 2022 13:38
Save wiseodd/b8d57fa029f876e00b336b7b3b5052bd to your computer and use it in GitHub Desktop.
Last-Layer Laplace for Image2Image Problems
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.002\n", | |
"Loss: 0.001\n", | |
"Loss: 0.000\n", | |
"Loss: 0.000\n", | |
"Loss: 0.000\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from torch import nn, optim\n", | |
"import torch.nn.functional as F\n", | |
"import torchvision as tv\n", | |
"import torchvision.transforms as transforms\n", | |
"\n", | |
"\n", | |
"transform = transforms.Compose([transforms.ToTensor()])\n", | |
"\n", | |
"batch_size = 128\n", | |
"DATA_ROOT = '~/Datasets'\n", | |
"\n", | |
"# download=True fetches MNIST on first use and is skipped when the files are\n", | |
"# already present, so the notebook survives Restart & Run All on a fresh machine.\n", | |
"trainset = tv.datasets.MNIST(\n", | |
"    root=DATA_ROOT, train=True, transform=transform, download=True\n", | |
")\n", | |
"trainloader = torch.utils.data.DataLoader(\n", | |
"    trainset, batch_size=batch_size, shuffle=True\n", | |
")\n", | |
"\n", | |
"testset = tv.datasets.MNIST(\n", | |
"    root=DATA_ROOT, train=False, transform=transform, download=True\n", | |
")\n", | |
"# Test order is kept fixed so reconstructions line up with dataset indices.\n", | |
"testloader = torch.utils.data.DataLoader(\n", | |
"    testset, batch_size=batch_size, shuffle=False\n", | |
")\n", | |
"\n", | |
"class Model(nn.Module):\n", | |
"    \"\"\"Small convolutional autoencoder, split for last-layer Laplace.\n", | |
"\n", | |
"    `feature_extractor` is the encoder; `last_layer` is a bias-free transposed\n", | |
"    convolution whose weight is the only parameter the later cells place a\n", | |
"    Laplace posterior over.\n", | |
"\n", | |
"    For 28x28 inputs the spatial size goes 28 -> 10 (stride-3 conv, padding 1)\n", | |
"    -> 28 (stride-3 transposed conv, padding 1), and the output is flattened\n", | |
"    to (batch, 784).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        self.feature_extractor = nn.Sequential(\n", | |
"            nn.Conv2d(in_channels=1, out_channels=100, kernel_size=3, stride=3, padding=1),\n", | |
"            nn.Sigmoid(),\n", | |
"        )\n", | |
"        # No bias: the weight alone fully parameterizes the probabilistic layer.\n", | |
"        self.last_layer = nn.Sequential(\n", | |
"            nn.ConvTranspose2d(\n", | |
"                in_channels=100, out_channels=1, kernel_size=3, stride=3, padding=1, bias=False\n", | |
"            ),\n", | |
"            nn.Flatten(start_dim=1),\n", | |
"        )\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        return self.last_layer(self.feature_extractor(x))\n", | |
"\n", | |
"\n", | |
"torch.manual_seed(0)  # reproducible weight init and shuffling order\n", | |
"\n", | |
"model = Model().cuda()\n", | |
"opt = optim.Adam(model.parameters(), lr=1e-3)\n", | |
"\n", | |
"for epoch in range(5):\n", | |
"    epoch_loss, n_batches = 0.0, 0\n", | |
"    for x, _ in trainloader:\n", | |
"        x = x.cuda()\n", | |
"        out = model(x)\n", | |
"        # Autoencoder objective: reconstruct the flattened input image.\n", | |
"        loss = F.mse_loss(out, x.flatten(1))\n", | |
"        loss.backward()\n", | |
"        opt.step()\n", | |
"        opt.zero_grad()\n", | |
"        # Accumulate on-device; calling .item() every step would force a sync.\n", | |
"        epoch_loss += loss.detach()\n", | |
"        n_batches += 1\n", | |
"\n", | |
"    # Report the epoch mean rather than the last mini-batch, in scientific\n", | |
"    # notation because the loss quickly falls below what '.3f' can resolve.\n", | |
"    print(f'Epoch {epoch + 1}: mean loss {float(epoch_loss) / n_batches:.3e}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Last-layer GGN diagonal with BackPACK" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([100, 1, 3, 3])\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "177e2d70fc1641128b68f1e2002fe451", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/469 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from backpack import extend, backpack\n", | |
"from backpack.extensions import DiagGGNExact\n", | |
"from tqdm.notebook import tqdm\n", | |
"\n", | |
"lastlayer = extend(model.last_layer)\n", | |
"# NOTE(review): the curvature of MSELoss(reduction='sum') is that of\n", | |
"# sum (y - f)^2, whose Hessian w.r.t. f is 2I -- twice the GGN of a\n", | |
"# unit-variance Gaussian NLL. This implicitly fixes the observation noise\n", | |
"# variance to 0.5; confirm this is intended.\n", | |
"lossfunc = extend(nn.MSELoss(reduction='sum'))\n", | |
"\n", | |
"# The last layer is a bias-free ConvTranspose2d, so its weight is the only\n", | |
"# parameter whose curvature is needed.\n", | |
"last_weight = model.last_layer[0].weight\n", | |
"print(last_weight.shape)\n", | |
"diag_G = torch.zeros_like(last_weight)\n", | |
"\n", | |
"for x, _ in tqdm(trainloader):\n", | |
"    x = x.cuda()\n", | |
"    # Features are constants for last-layer Laplace: computing them without\n", | |
"    # autograd avoids building a graph through (and accumulating .grad into)\n", | |
"    # the feature extractor on every batch.\n", | |
"    with torch.no_grad():\n", | |
"        phi = model.feature_extractor(x)\n", | |
"    loss = lossfunc(lastlayer(phi), x.flatten(1))\n", | |
"\n", | |
"    with backpack(DiagGGNExact()):\n", | |
"        loss.backward()\n", | |
"\n", | |
"    # Per-batch GGN diagonal of the summed loss; summing over batches gives\n", | |
"    # the GGN diagonal of the loss over the whole training set.\n", | |
"    diag_G += last_weight.diag_ggn_exact" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Predictive mean and variance" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import unfoldNd\n", | |
"\n", | |
"# Isotropic Gaussian prior precision on the last-layer weight.\n", | |
"prec0 = 1\n", | |
"\n", | |
"# Diagonal Laplace posterior covariance: (GGN diagonal + prior precision)^-1.\n", | |
"diag_Sigma = 1/(diag_G + prec0)\n", | |
"# The weight is stored as (c_in, c_out, k, k); reorder to (c_out, c_in*k*k) so\n", | |
"# its columns line up with the rows of the unfolded features used below.\n", | |
"diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)\n", | |
"\n", | |
"assert diag_Sigma.shape == (1, 100*3*3), diag_Sigma.shape\n", | |
"\n", | |
"# The unfold must use exactly the last layer's hyperparameters; fail loudly\n", | |
"# if the model definition changes without this cell being updated.\n", | |
"last_conv = model.last_layer[0]\n", | |
"assert (\n", | |
"    last_conv.kernel_size, last_conv.stride, last_conv.padding,\n", | |
"    last_conv.dilation, last_conv.output_padding,\n", | |
") == ((3, 3), (3, 3), (1, 1), (1, 1), (0, 0)), 'unfold_transpose is out of sync with model.last_layer[0]'\n", | |
"unfold_transpose = unfoldNd.UnfoldTransposeNd(\n", | |
"    kernel_size=3, dilation=1, padding=1, stride=3\n", | |
")\n", | |
"\n", | |
"@torch.no_grad()\n", | |
"def reconstruct(x):\n", | |
"    \"\"\"Return the MAP reconstruction of `x` and its last-layer Laplace variance.\n", | |
"\n", | |
"    Both results are NumPy arrays with the same shape as `x`.\n", | |
"    \"\"\"\n", | |
"    features = model.feature_extractor(x)\n", | |
"    map_recon = model.last_layer(features).reshape(x.shape)\n", | |
"\n", | |
"    # The last layer is linear in its weight, so the Jacobian w.r.t. the weight is\n", | |
"    # the unfolded features; with a diagonal covariance the output variance is\n", | |
"    # var[b, k, j] = sum_i jac[b, i, j]^2 * diag_Sigma[k, i].\n", | |
"    jac = unfold_transpose(features)\n", | |
"    pixel_var = torch.einsum('bij,ki,bij->bkj', jac, diag_Sigma, jac)\n", | |
"    pixel_var = pixel_var.reshape(map_recon.shape)\n", | |
"\n", | |
"    return map_recon.cpu().numpy(), pixel_var.cpu().numpy()\n", | |
"\n", | |
"# One (mean, variance) pair of NumPy arrays per test batch, in loader order.\n", | |
"x_recons = [reconstruct(x.cuda()) for x, _ in testloader]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.9.5 ('base')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.5" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "2478d115bd7922f0ed574df2e779addadfa7a60215a3648b6e3cf31cb3b0451d" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.