@wiseodd
Created November 23, 2022 13:38

Last-Layer Laplace for Image2Image Problems
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.002\n",
"Loss: 0.001\n",
"Loss: 0.000\n",
"Loss: 0.000\n",
"Loss: 0.000\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"import torch.nn.functional as F\n",
"import torchvision as tv\n",
"import torchvision.transforms as transforms\n",
"\n",
"\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"\n",
"batch_size = 128\n",
"\n",
"trainset = tv.datasets.MNIST(\n",
" root='~/Datasets', train=True, transform=transform\n",
")\n",
"trainloader = torch.utils.data.DataLoader(\n",
" trainset, batch_size=batch_size, shuffle=True\n",
")\n",
"\n",
"testset = tv.datasets.MNIST(\n",
" root='~/Datasets', train=False, transform=transform\n",
")\n",
"testloader = torch.utils.data.DataLoader(\n",
" testset, batch_size=batch_size, shuffle=False\n",
")\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" self.feature_extractor = nn.Sequential(\n",
" nn.Conv2d(1, 100, 3, 3, 1),\n",
" nn.Sigmoid(), \n",
" )\n",
" self.last_layer = nn.Sequential(\n",
" nn.ConvTranspose2d(100, 1, 3, 3, 1, bias=False),\n",
" nn.Flatten(1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.feature_extractor(x)\n",
" return self.last_layer(x)\n",
"\n",
"\n",
"model = Model().cuda() \n",
"opt = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"for epoch in range(5):\n",
" for x, _ in trainloader:\n",
" x = x.cuda()\n",
" out = model(x)\n",
" loss = F.mse_loss(out, x.flatten(1))\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" print(f'Loss: {loss.item():.3f}')"
]
},
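{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added here, not part of the original gist): with kernel size 3, stride 3, and padding 1, the encoder maps a `1x28x28` MNIST image to `100x10x10` features, and the bias-free transposed convolution maps those features back to a flattened `28*28 = 784` output. The cell below just confirms these shapes on a dummy batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shape sanity check on a dummy batch (added for illustration)\n",
"with torch.no_grad():\n",
"    dummy = torch.zeros(2, 1, 28, 28).cuda()\n",
"    phi = model.feature_extractor(dummy)\n",
"    out = model.last_layer(phi)\n",
"\n",
"print(phi.shape)  # expected: torch.Size([2, 100, 10, 10])\n",
"print(out.shape)  # expected: torch.Size([2, 784])"
]
},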
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using BackPACK"
]
},
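{
"cell_type": "markdown",
"metadata": {},
"source": [
"A brief recap of the setup, as implemented below: the last layer is linear in its weights $W$, so a Laplace approximation places a Gaussian posterior $p(W \\mid \\mathcal{D}) \\approx \\mathcal{N}(W_\\mathrm{MAP}, \\Sigma)$ over them, here with a diagonal covariance\n",
"\n",
"$$\\Sigma = \\big(\\mathrm{diag}(G) + \\tau_0 I\\big)^{-1},$$\n",
"\n",
"where $G$ is the generalized Gauss-Newton (GGN) matrix of the training loss w.r.t. $W$ and $\\tau_0$ is the prior precision (`prec0` in the prediction section). The next cell accumulates `diag_G`, the diagonal of $G$, over the training set with BackPACK's `DiagGGNExact` extension."
]
},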
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([100, 1, 3, 3])\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "177e2d70fc1641128b68f1e2002fe451",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/469 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from backpack import extend, backpack\n",
"from backpack.extensions import DiagGGNExact\n",
"from tqdm.notebook import tqdm\n",
"\n",
"lastlayer = extend(model.last_layer)\n",
"lossfunc = extend(nn.MSELoss(reduction='sum'))\n",
"\n",
"print(model.last_layer[0].weight.shape)\n",
"diag_G = torch.zeros_like(model.last_layer[0].weight)\n",
"\n",
"for x, _ in tqdm(trainloader):\n",
" x = x.cuda()\n",
" loss = lossfunc(lastlayer(model.feature_extractor(x)), x.flatten(1))\n",
"\n",
" with backpack(DiagGGNExact()):\n",
" loss.backward()\n",
"\n",
" diag_G += model.last_layer[0].weight.diag_ggn_exact"
]
},
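{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small check (added here, not in the original gist): PyTorch's `ConvTranspose2d` stores its weight as `(in_channels, out_channels, k, k)`, i.e. `(100, 1, 3, 3)` here, so `diag_G` has that shape; the prediction section below transposes it to `(c_out, c_in*k*k)`. The GGN is positive semi-definite, so its diagonal should be non-negative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check: diag_G lives in the ConvTranspose2d weight shape (in_channels, out_channels, k, k)\n",
"assert diag_G.shape == (100, 1, 3, 3)\n",
"assert (diag_G >= 0).all()  # GGN diagonal is non-negative"
]
},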
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predictions"
]
},
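{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the last layer is linear in its weights, the reconstruction of an input $x$ is $f(x) = J_x\\, \\mathrm{vec}(W)$, where the Jacobian $J_x$ consists of the encoder features $\\phi(x)$ unfolded according to the transposed convolution (this is what `unfoldNd.UnfoldTransposeNd` computes). The predictive variance of each output pixel is then the corresponding diagonal entry of $J_x \\Sigma J_x^\\top$; with a diagonal $\\Sigma$ this reduces to the `einsum` in the cell below."
]
},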
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import unfoldNd\n",
"\n",
"prec0 = 1\n",
"\n",
"# Laplace cov\n",
"diag_Sigma = 1/(diag_G + prec0)\n",
"diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)\n",
"\n",
"# diag_Sigma.shape should be (c_out, c_in*k*k\n",
"# )\n",
"assert len(diag_Sigma.shape) == 2 and diag_Sigma.shape == (1, 100*3*3)\n",
"\n",
"# Following the last layer of the model\n",
"unfold_transpose = unfoldNd.UnfoldTransposeNd(\n",
" kernel_size=3, dilation=1, padding=1, stride=3\n",
")\n",
"\n",
"@torch.no_grad()\n",
"def reconstruct(x):\n",
" phi = model.feature_extractor(x)\n",
" \n",
" # MAP prediction\n",
" mean_pred = model.last_layer(phi).reshape(x.shape)\n",
"\n",
" # Variance\n",
" J_pred = unfold_transpose(phi)\n",
" var_pred = torch.einsum('bij,ki,bij->bkj', J_pred, diag_Sigma, J_pred).reshape(mean_pred.shape)\n",
"\n",
" return mean_pred.cpu().numpy(), var_pred.cpu().numpy()\n",
"\n",
"x_recons = []\n",
"\n",
"for x, _ in testloader:\n",
" x = x.cuda()\n",
" x_recons.append(reconstruct(x))"
]
},
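{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible way to look at the results (added here, not in the original gist; assumes `matplotlib` is installed): plot the MAP reconstruction next to the per-pixel predictive standard deviation for a few test digits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical visualization of the first test batch (added for illustration)\n",
"import matplotlib.pyplot as plt\n",
"\n",
"mean_batch, var_batch = x_recons[0]  # arrays of shape (batch, 1, 28, 28)\n",
"\n",
"fig, axes = plt.subplots(2, 5, figsize=(10, 4))\n",
"for i in range(5):\n",
"    axes[0, i].imshow(mean_batch[i, 0], cmap='gray')\n",
"    axes[0, i].set_title('MAP mean')\n",
"    axes[0, i].axis('off')\n",
"    axes[1, i].imshow(var_batch[i, 0]**0.5, cmap='viridis')\n",
"    axes[1, i].set_title('pred. std')\n",
"    axes[1, i].axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},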
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.5 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "2478d115bd7922f0ed574df2e779addadfa7a60215a3648b6e3cf31cb3b0451d"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}