VAE
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Variational Autoencoders\n", | |
"\n", | |
"VAEs consider two-step generative process by a prior over latent space $p(z)$ and a conditional generative distribution $p_{\\theta}(x|z)$, which is parametrized by a deep neural network (DNN). Our goal is to maximize marginal log-likehood which is intractable in general case. Therefore, variational inference (VI) framework is considered.\n", | |
"\n", | |
"\\begin{equation*}\n", | |
" \\begin{aligned}\n", | |
" & \\log p(x) \\geq \\mathcal{L}(x;\\theta;q) = \\mathbb{E}_{z\\sim q(z)}[\\log p_{\\theta}(x|z)] - \\text{KL}[q(z)\\|p(z)],\n", | |
" \\end{aligned}\n", | |
"\\end{equation*}\n", | |
"\n", | |
"where $q(z|x)$ is a variational posterior distribution. Given data distribution $p_e(x) = \\frac1N\\sum_{i=1}^N \\delta_{x_i}$ we aim at maximizing the average marginal log-likelihood. Following the variational auto-encoder architecture amortized inference is proposed by choice of the variational distribution $q_{\\phi}(z|x)$ which is also parametrized by DNN.\n", | |
"\n", | |
"\\begin{equation*}\n", | |
" \\begin{aligned}\n", | |
" & \\arg\\max\\limits_{\\phi, \\theta}\\mathbb{E}_{x}\\mathcal{L}(x,\\phi,\\theta)=\\arg\\max\\limits_{\\phi, \\theta}\\mathbb{E}_{x}\\mathbb{E}_{z\\sim q(z)}[\\log p_{\\theta}(x|z)] - \\mathbb{E}_x \\text{KL}[q_{\\phi}(z|x)\\|p(z)].\n", | |
" \\end{aligned}\n", | |
"\\end{equation*}\n", | |
"\n", | |
"To evaluate the performance of the VAE approach, we will estimate a negative log likelihood (NLL) on the test set. NLL is calculated by importance sampling method:\n", | |
"\\begin{equation*}\n", | |
" - \\log p(x) \\approx - \\log \\frac{1}{K} \n", | |
" \\sum_{i = 1}^K \\frac{p_\\theta(x | z_i) p(z_i)}{q_\\phi(z_i | x)},\\,\\,\\,\\,z_i \\sim q_\\phi(z | x) \n", | |
"\\end{equation*}" | |
] | |
}, | |
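{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a quick sanity check of the importance sampling estimator above, below is a minimal sketch for a toy 1-D model where $\\log p(x)$ is known in closed form. All names in it are illustrative, not part of the assignment." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Minimal sketch of the importance sampling log-likelihood estimator for a\n", | |
"# toy 1-D model: z ~ N(0, 1), x|z ~ N(z, toy_decoder_std^2). For this model\n", | |
"# log p(x) is known in closed form, so we can compare against the estimate.\n", | |
"import math\n", | |
"import torch\n", | |
"from torch.distributions import Normal\n", | |
"\n", | |
"toy_p_z = Normal(0.0, 1.0)  # prior p(z)\n", | |
"toy_decoder_std = 0.5\n", | |
"toy_x = torch.tensor(0.3)\n", | |
"toy_q_z_given_x = Normal(0.2, 0.8)  # a stand-in for an encoder's output\n", | |
"\n", | |
"K = 10000\n", | |
"toy_z = toy_q_z_given_x.sample((K,))\n", | |
"log_w = (\n", | |
"    Normal(toy_z, toy_decoder_std).log_prob(toy_x)  # log p(x|z)\n", | |
"    + toy_p_z.log_prob(toy_z) - toy_q_z_given_x.log_prob(toy_z)\n", | |
")\n", | |
"log_p_x_estimate = torch.logsumexp(log_w, dim=0) - math.log(K)\n", | |
"\n", | |
"# Closed form: x = z + noise implies p(x) = N(0, 1 + toy_decoder_std^2).\n", | |
"log_p_x_true = Normal(0.0, math.sqrt(1 + toy_decoder_std**2)).log_prob(toy_x)\n", | |
"print(float(log_p_x_estimate), float(log_p_x_true))  # should be close" | |
] | |
}, | |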
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### References\n", | |
"1. Auto-Encoding Variational Bayes https://arxiv.org/pdf/1312.6114.pdf\n", | |
"2. Beta-VAE https://pdfs.semanticscholar.org/a902/26c41b79f8b06007609f39f82757073641e2.pdf\n", | |
"3. Importance Weighted Autoencoders https://arxiv.org/pdf/1509.00519.pdf \n", | |
"4. VAE with a VampPrior https://arxiv.org/pdf/1705.07120.pdf " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## VAE\n", | |
"Implement all the method in the class `SimpleVAE` below, using the formular above and detail in the documentation" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Install all libraries required to run this notebook\n", | |
"\n", | |
"```shell\n", | |
"$ conda create --name hw3test python=3.6\n", | |
"$ conda activate hw3test\n", | |
"\n", | |
"# install pytorch - consult https://pytorch.org/get-started/locally, for me the command below was enough\n", | |
"$ conda install pytorch torchvision cudatoolkit=9.0 -c pytorch\n", | |
"\n", | |
"$ pip install einops pytorch-ignite tensorboardX libcrap\n", | |
"$ conda install jupyterlab matplotlib\n", | |
"```\n", | |
"\n", | |
"Also, to view tensorboard logs, install tensorboard. Google to learn how to do it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from itertools import chain\n", | |
"from typing import *\n", | |
"import math\n", | |
"from datetime import datetime\n", | |
"import os.path\n", | |
"from functools import partial\n", | |
"\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import matplotlib.colors as mcolors\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as tnnf\n", | |
"from torch.nn.init import xavier_normal_\n", | |
"from torchvision.transforms import Normalize\n", | |
"from torch.distributions import Normal, Independent, Distribution, kl_divergence\n", | |
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n", | |
"from torchvision.utils import make_grid\n", | |
"from torch.utils.data import SubsetRandomSampler\n", | |
"from torch.optim import SGD, Adam\n", | |
"\n", | |
"\n", | |
"from einops import rearrange, reduce # pip install einops, or see https://github.com/arogozhnikov/einops\n", | |
"from einops.layers.torch import Rearrange, Reduce\n", | |
"\n", | |
"from ignite.engine import Engine, Events # pip install pytorch-ignite, I have version 0.2.0\n", | |
"from ignite.handlers import ModelCheckpoint\n", | |
"from ignite.contrib.handlers.tensorboard_logger import (\n", | |
" TensorboardLogger, OutputHandler, WeightsScalarHandler, GradsScalarHandler, WeightsHistHandler, GradsHistHandler\n", | |
")\n", | |
"\n", | |
"# You also need either pip install tensorboardX, or use pytorch ≥ 1.2.0. I haven't tried the latter." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_now_as_str(utc: bool =True, year: bool =False, seconds: bool =False) -> str:\n", | |
" \"\"\"Returns current date and time as a string in format like\n", | |
" 'UTC05-22T05:40' (if utc is True and year and seconds are false).\n", | |
" If utc is False, returns something like '05-22T08:40', where date and time are in the\n", | |
" local timezone of the computer.\n", | |
" If year is True, also returns a year: '2019-05-22T08:40'.\n", | |
" If seconds is True, also returns seconds: '2019-05-22T08:40:33'.\"\"\"\n", | |
" format = \"%m-%dT%H:%M\"\n", | |
" if seconds:\n", | |
" format += \":%S\"\n", | |
" if year:\n", | |
" format = \"%Y-\" + format\n", | |
" if utc:\n", | |
" now = datetime.utcnow()\n", | |
" format = \"UTC\" + format\n", | |
" else:\n", | |
" now = datetime.now()\n", | |
" return datetime.strftime(now, format)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"device = cuda\n" | |
] | |
} | |
], | |
"source": [ | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"print(f\"device = {device}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_unit_normal_like(distribution: Distribution, device):\n", | |
" unit_normal = Independent(\n", | |
" Normal(torch.zeros(*distribution.batch_shape, *distribution.event_shape, device=device), 1.0),\n", | |
" len(distribution.event_shape))\n", | |
" assert unit_normal.batch_shape == distribution.batch_shape\n", | |
" assert unit_normal.event_shape == distribution.event_shape\n", | |
" return unit_normal" | |
] | |
}, | |
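{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A quick shape check for `make_unit_normal_like` (a sketch, not required by the assignment): the returned distribution should be a standard normal with the same batch and event structure as its argument, which is exactly what `calc_average_kl` below relies on." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: check that make_unit_normal_like preserves batch_shape and\n", | |
"# event_shape, and that KL divergence to it is well-defined. Illustrative only.\n", | |
"example_q = Independent(Normal(torch.randn(4, 7), torch.rand(4, 7) + 0.1), 1)\n", | |
"example_p = make_unit_normal_like(example_q, device=\"cpu\")\n", | |
"assert example_p.batch_shape == example_q.batch_shape == (4,)\n", | |
"assert example_p.event_shape == example_q.event_shape == (7,)\n", | |
"print(kl_divergence(example_q, example_p).shape)  # torch.Size([4])" | |
] | |
}, | |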
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CPConv3x3(nn.Sequential):\n", | |
" def __init__(self, in_channels, out_channels, stride=1, rank: Optional[int] = None, deconv: bool = False):\n", | |
" assert stride in (1, 2)\n", | |
" max_possible_rank = 9 * min(in_channels, out_channels)\n", | |
" if rank is None:\n", | |
" rank = max_possible_rank // 6\n", | |
" assert rank <= max_possible_rank\n", | |
" super().__init__()\n", | |
" self.add_module(\"pw1\", nn.Conv2d(in_channels, rank, kernel_size=1, bias=False))\n", | |
" if not deconv:\n", | |
" self.add_module(\"dw\", nn.Conv2d(rank, rank, kernel_size=3, groups=rank, bias=False, stride=stride, padding=1))\n", | |
" else:\n", | |
" self.add_module(\"dw\",nn.ConvTranspose2d(\n", | |
" rank, rank, kernel_size=3, groups=rank, bias=False,\n", | |
" stride=stride, padding=1, output_padding=1 if stride == 2 else 0\n", | |
" ))\n", | |
" self.add_module(\"pw2\", nn.Conv2d(rank, out_channels, kernel_size=1, bias=True))" | |
] | |
}, | |
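{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`CPConv3x3` factorizes a 3×3 convolution into a pointwise → depthwise → pointwise chain (a CP-style decomposition). Below is a sketch (illustrative only; the encoder and decoder later are built with `cp=False`) comparing its parameter count to a dense 3×3 convolution and checking that the output shapes match." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: CPConv3x3 should produce the same output shape as a dense 3x3 conv\n", | |
"# while using fewer parameters. Illustrative only.\n", | |
"dense_conv = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n", | |
"cp_conv = CPConv3x3(32, 64, stride=2)  # default rank = 9 * min(32, 64) // 6 = 48\n", | |
"\n", | |
"def count_params(module: nn.Module) -> int:\n", | |
"    return sum(p.numel() for p in module.parameters())\n", | |
"\n", | |
"print(count_params(dense_conv), count_params(cp_conv))  # 18496 vs 5104\n", | |
"example_input = torch.randn(2, 32, 16, 16)\n", | |
"assert dense_conv(example_input).shape == cp_conv(example_input).shape == (2, 64, 8, 8)" | |
] | |
}, | |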
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_conv3x3(in_channels, out_channels, stride:int = 1, deconv: bool = False, cp: bool = False) -> nn.Module:\n", | |
" assert stride in (1, 2)\n", | |
" if cp:\n", | |
" return CPConv3x3(in_channels, out_channels, stride=stride, deconv=deconv)\n", | |
" elif not deconv:\n", | |
" return nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)\n", | |
" else:\n", | |
" return nn.ConvTranspose2d(in_channels, out_channels, 3, stride=stride, padding=1, output_padding=1 if stride == 2 else 0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ResidualBlock(nn.Module):\n", | |
" def __init__(self, in_channels, out_channels, deconv: bool, cp: bool, elu_in_the_end: bool = True):\n", | |
" super().__init__()\n", | |
" self.deconv = deconv\n", | |
" self.conv1 = make_conv3x3(in_channels, out_channels, stride=2, deconv=deconv, cp=cp)\n", | |
" self.conv2 = make_conv3x3(out_channels, out_channels, deconv=deconv, cp=cp)\n", | |
" self.skip_connection_channels = min(in_channels, out_channels)\n", | |
" self.elu_in_the_end = elu_in_the_end\n", | |
" last_bias = self.conv2.bias if not cp else self.conv2.pw2.bias\n", | |
" nn.init.constant_(last_bias, 0.0)\n", | |
" \n", | |
" def skip_connection(self, x) -> torch.Tensor:\n", | |
" if not self.deconv:\n", | |
" return reduce(\n", | |
" x, \"b c (h dh) (w dw) -> b c h w\", \"mean\", dh=2, dw=2\n", | |
" )\n", | |
" else:\n", | |
" return tnnf.interpolate(x, scale_factor=2)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" out = self.conv2(tnnf.elu(self.conv1(x)))\n", | |
" out[:, :self.skip_connection_channels] += self.skip_connection(x[:, :self.skip_connection_channels])\n", | |
" if self.elu_in_the_end:\n", | |
" out = tnnf.elu(out)\n", | |
" return out" | |
] | |
}, | |
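{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A shape sketch for `ResidualBlock` (illustrative only): with `deconv=False` the block halves the spatial side via a strided convolution, with `deconv=True` it doubles it via a transposed convolution, and the skip connection average-pools or nearest-neighbour upsamples the first `min(in_channels, out_channels)` channels of the input." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: a downsampling block halves the spatial side, an upsampling block\n", | |
"# doubles it. Illustrative only.\n", | |
"down_block = ResidualBlock(3, 9, deconv=False, cp=False)\n", | |
"up_block = ResidualBlock(9, 3, deconv=True, cp=False)\n", | |
"example_images = torch.randn(2, 3, 64, 64)\n", | |
"downsampled = down_block(example_images)\n", | |
"assert downsampled.shape == (2, 9, 32, 32)\n", | |
"assert up_block(downsampled).shape == (2, 3, 64, 64)" | |
] | |
}, | |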
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class LessShittyEncoder(nn.Module):\n", | |
" def __init__(self, in_channels, image_side, hid_dim, channels_increase_per_block, init_log_var_bias: float):\n", | |
" super().__init__()\n", | |
" B = 6\n", | |
" self.channels_increase_per_block = channels_increase_per_block\n", | |
" self.convolutions = nn.Sequential(*(\n", | |
" ResidualBlock(\n", | |
" in_channels * channels_increase_per_block**b, in_channels * channels_increase_per_block**(b+1),\n", | |
" deconv=False, cp=False\n", | |
" )\n", | |
" for b in range(0, B)\n", | |
" ))\n", | |
" self.linear_to_mean = nn.Linear(in_channels * channels_increase_per_block**B * (image_side // 2**B)**2, hid_dim)\n", | |
" self.linear_to_log_var = nn.Linear(in_channels * channels_increase_per_block**B * (image_side // 2**B)**2, hid_dim)\n", | |
" nn.init.constant_(self.linear_to_log_var.bias, init_log_var_bias)\n", | |
" \n", | |
" def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:\n", | |
" foo = rearrange(self.convolutions(x), \"b c h w -> b (c h w)\")\n", | |
" return self.linear_to_mean(foo), self.linear_to_log_var(foo)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class LessShittyDecoder(nn.Sequential):\n", | |
" def __init__(self, in_channels, hid_dim, image_side, channels_increase_per_block):\n", | |
" B = int(math.log2(image_side))\n", | |
" super().__init__(\n", | |
" nn.Linear(hid_dim, in_channels * channels_increase_per_block**B),\n", | |
" Rearrange(\"b c -> b c () ()\"),\n", | |
" *(\n", | |
" ResidualBlock(\n", | |
" in_channels * channels_increase_per_block**(b+1), in_channels * channels_increase_per_block**b,\n", | |
" deconv=True, cp=False\n", | |
" )\n", | |
" for b in range(B-1, 0, -1)\n", | |
" ),\n", | |
" ResidualBlock(in_channels * channels_increase_per_block, in_channels, deconv=True, cp=False, elu_in_the_end=False)\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SimpleVAE(nn.Module):\n", | |
" def __init__(self, hid_dim, p_x_given_z_std, channels_increase_per_block):\n", | |
" super().__init__()\n", | |
" in_channels = 3\n", | |
" image_side = 64\n", | |
" self.hid_dim = hid_dim\n", | |
" self.encoder = LessShittyEncoder(in_channels, image_side, hid_dim, channels_increase_per_block, 0.0)\n", | |
" self.decoder = LessShittyDecoder(in_channels, hid_dim, image_side, channels_increase_per_block)\n", | |
" self.p_x_given_z_std = p_x_given_z_std\n", | |
" self.hid_dim = hid_dim\n", | |
" self.init_params()\n", | |
" \n", | |
" def init_params(self):\n", | |
" for m in self.modules():\n", | |
" if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n", | |
" nn.init.xavier_normal_(m.weight.data)\n", | |
" \n", | |
" def calc_q_z_given_x(self, x: torch.Tensor) -> Distribution:\n", | |
" mean, log_var = self.encoder(x)\n", | |
" log_std = log_var * 0.5\n", | |
" std = torch.exp(log_std)\n", | |
" q_z_given_x = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)\n", | |
" assert q_z_given_x.batch_shape == (x.shape[0],)\n", | |
" assert q_z_given_x.event_shape == (self.hid_dim,)\n", | |
" return q_z_given_x\n", | |
"\n", | |
" def forward(self, x, num_z_samples_per_image: int) -> Tuple[Distribution, torch.Tensor, Distribution]:\n", | |
" q_z_given_x = self.calc_q_z_given_x(x)\n", | |
" \n", | |
" z_realization = q_z_given_x.rsample((num_z_samples_per_image,))\n", | |
" assert z_realization.shape == (num_z_samples_per_image, x.shape[0], self.hid_dim)\n", | |
" reconstruction = rearrange(\n", | |
" self.decoder(rearrange(z_realization, \"k b c -> (k b) c\")),\n", | |
" \"(k b) c h w -> b k c h w\", k=num_z_samples_per_image\n", | |
" )\n", | |
" \n", | |
" p_x_given_z_realization = Independent(\n", | |
" Normal(reconstruction, self.p_x_given_z_std),\n", | |
" reinterpreted_batch_ndims=3\n", | |
" )\n", | |
" assert p_x_given_z_realization.batch_shape == (x.shape[0], num_z_samples_per_image)\n", | |
" return q_z_given_x, z_realization, p_x_given_z_realization\n", | |
" \n", | |
" def generate_unconditional_samples(self, count: int) -> torch.Tensor:\n", | |
" z = Normal(torch.zeros(self.hid_dim, device=device), 1.0).sample((count, ))\n", | |
" assert z.shape == (count, self.hid_dim)\n", | |
" return self.decoder(z)\n", | |
" \n", | |
" def calc_mean_iwll(self, x, num_samples_per_image: int) -> torch.Tensor:\n", | |
" \"\"\"Estimate of log likelihood , i.e. log p(x), by importance sampling.\n", | |
" This isn't used in training, only for evaluation.\"\"\"\n", | |
" with torch.no_grad():\n", | |
" lls = torch.zeros(num_samples_per_image, x.shape[0])\n", | |
" q_z_given_x = self.calc_q_z_given_x(x)\n", | |
" p_z = Independent(Normal(torch.ones(self.hid_dim, device=x.device), 1.0), reinterpreted_batch_ndims=1)\n", | |
" for k in range(num_samples_per_image):\n", | |
" z_realization = q_z_given_x.sample()\n", | |
" assert z_realization.shape == (x.shape[0], self.hid_dim)\n", | |
" reconstruction = self.decoder(z_realization)\n", | |
" assert reconstruction.shape == (x.shape)\n", | |
" p_x_given_z_realization = Independent(\n", | |
" Normal(reconstruction, self.p_x_given_z_std),\n", | |
" reinterpreted_batch_ndims=3\n", | |
" )\n", | |
" ll = p_x_given_z_realization.log_prob(x) + p_z.log_prob(z_realization) - q_z_given_x.log_prob(z_realization)\n", | |
" assert ll.shape == (x.shape[0], )\n", | |
" lls[k] = ll\n", | |
" return torch.mean(torch.logsumexp(lls, dim=0) - torch.log(torch.tensor(num_samples_per_image, dtype=torch.float32)))\n", | |
"\n", | |
" def calc_average_kl(self, q_z_given_x: Distribution) -> torch.Tensor:\n", | |
" \"\"\"Calculates kl divergence between q(z|x) and N(0, I) and returns its average over images.\"\"\"\n", | |
" return torch.mean(kl_divergence(q_z_given_x, make_unit_normal_like(q_z_given_x, device=q_z_given_x.mean.device)))\n", | |
" \n", | |
" def calc_mean_log_cond_prob(self, p_x_given_z_realization, x) -> torch.Tensor:\n", | |
" \"\"\"This calculates an estimate of the non-KL term in VAE loss (not IWAE, doesn't use importance sampling).\"\"\"\n", | |
" log_cond_prob = rearrange(p_x_given_z_realization.log_prob(x.unsqueeze(1)), \"b () ... -> b ...\")\n", | |
" assert log_cond_prob.shape == (x.shape[0], )\n", | |
" return torch.mean(log_cond_prob)" | |
] | |
}, | |
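{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A smoke-test sketch for `SimpleVAE` on random data, checking the shapes that `forward` promises. It uses a tiny throwaway model on the CPU and is illustrative only." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: smoke-test SimpleVAE's forward shapes on random data.\n", | |
"# Tiny throwaway model on CPU, illustrative only.\n", | |
"tiny_vae = SimpleVAE(hid_dim=8, p_x_given_z_std=1e-1, channels_increase_per_block=2)\n", | |
"example_x = torch.randn(4, 3, 64, 64)\n", | |
"q_z, z_sample, p_x = tiny_vae(example_x, num_z_samples_per_image=2)\n", | |
"assert q_z.batch_shape == (4,) and q_z.event_shape == (8,)\n", | |
"assert z_sample.shape == (2, 4, 8)  # (num z samples, batch, hid_dim)\n", | |
"assert p_x.batch_shape == (4, 2) and p_x.event_shape == (3, 64, 64)" | |
] | |
}, | |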
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def calc_beta_vae_loss(mean_kl: torch.Tensor, mean_log_cond_prob: torch.Tensor, beta: float):\n", | |
" \"\"\"\n", | |
" This is β-VAE loss, not IWAE. This is used for training.\n", | |
" \"\"\"\n", | |
" return -mean_log_cond_prob + beta*mean_kl" | |
] | |
}, | |
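{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For reference, with the notation from the introduction this loss is\n", | |
"\\begin{equation*}\n", | |
"    -\\mathbb{E}_{z\\sim q_{\\phi}(z|x)}[\\log p_{\\theta}(x|z)] + \\beta\\,\\text{KL}[q_{\\phi}(z|x)\\|p(z)],\n", | |
"\\end{equation*}\n", | |
"estimated with one latent sample and averaged over the batch, so with $\\beta = 1$ minimizing it is equivalent to maximizing the ELBO." | |
] | |
}, | |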
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Generalization\n", | |
"\n", | |
"We will evaluate generalization of our generative model using the approahc discussed in https://arxiv.org/abs/1811.03259.\n", | |
"\n", | |
"Download the dataset with dots [here](https://drive.google.com/open?id=1CsDMOIGEsD1l3BLhuQDfEfEmLEb83wMz)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"1. [50 pts] Train your VAE on the **subset** of dots contating only 3 points (use batch 0-5 for training and leave 6 and 7 as a test set). Plot ELBO vs Iteration, KL vs Iteration \n", | |
"2. [10 pts] Calculate NLL on a test set, contating only 3 dots \n", | |
"3. [10 pts] Calculate NLL on a test set, containing only 5 dots \n", | |
"4. [30 pts] Sample 25 images from the VAE and draw then on the 5 $\\times$ 5 grid. Comment on the generalization ability of the model " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# You should change these paths\n", | |
"DATASET_ROOT = \"/mnt/hdd_1tb/datasets/dots\"\n", | |
"OUTPUTS_ROOT_DIR = \"/mnt/important/experiments/2019-10_vae_dots\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def load_images_file(path: str) -> torch.Tensor:\n", | |
" tensor = torch.from_numpy(rearrange(np.load(path)[\"images\"], \"b h w c -> b c h w\"))\n", | |
" assert tensor.shape == (8192, 3, 64, 64)\n", | |
" return tensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_dataset_mean = torch.tensor([0.966, 0.966, 0.966])\n", | |
"train_dataset_std = torch.tensor([0.1417, 0.1418, 0.1417])\n", | |
"\n", | |
"class DotsDataset(TensorDataset):\n", | |
" def __init__(self, root_path: str, num_dots: int, batches_nums: Iterable[int]):\n", | |
" self.paths = tuple(\n", | |
" os.path.join(root_path, f\"{num_dots}_dots\", f\"batch{batch_num}.npz\")\n", | |
" for batch_num in batches_nums\n", | |
" )\n", | |
" tensor = (\n", | |
" torch.cat(tuple(load_images_file(path) for path in self.paths))\n", | |
" - rearrange(train_dataset_mean, \"c -> c () ()\")\n", | |
" ) / rearrange(train_dataset_std, \"c -> c () ()\")\n", | |
" super().__init__(tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_dataset = DotsDataset(DATASET_ROOT, num_dots=3, batches_nums=range(0, 6))" | |
] | |
}, | |
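{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The normalization constants `train_dataset_mean` and `train_dataset_std` above are hardcoded. A sketch of how to sanity-check them: after normalization, each channel of the training tensor should have mean close to 0 and standard deviation close to 1." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Sketch: sanity-check the hardcoded normalization constants. Each channel of\n", | |
"# the normalized training tensor should have mean ~0 and std ~1.\n", | |
"# Note: rearrange makes a temporary copy of the whole (~2.5 GB) tensor here.\n", | |
"flat_per_channel = rearrange(train_dataset.tensors[0], \"b c h w -> c (b h w)\")\n", | |
"print(flat_per_channel.mean(dim=1))  # expect values near 0\n", | |
"print(flat_per_channel.std(dim=1))   # expect values near 1" | |
] | |
}, | |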
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_dataset_3_dots = DotsDataset(DATASET_ROOT, num_dots=3, batches_nums=(6, 7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_dataset_5_dots = DotsDataset(DATASET_ROOT, num_dots=5, batches_nums=range(0, 8))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pixel_values_range = (train_dataset.tensors[0].min().item(), train_dataset.tensors[0].max().item())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def show_random_image_from_dataset(dataset: DotsDataset) -> None:\n", | |
" normalized = make_grid(\n", | |
" [dataset[np.random.randint(0, len(dataset))][0]], nrow=1, normalize=True, range=pixel_values_range\n", | |
" )\n", | |
" plt.imshow(\n", | |
" rearrange(normalized, \"c h w -> h w c\").cpu())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Below we check samples from the 3 datasets to see if everything's ok." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQ3UlEQVR4nO3dfYxddZ3H8fenU8pjCy1zWyptGZUCsiqFDA8uLkFqsSpL624wmuh2DbHZDRpMTKDsJrsxm2z4y+Bu1E3jUxNZoQG0XZZFmlFcFQIdhEqh4PBQ6djSmQJdHlo60H73jzk9955h2rlz7z33jv4+r2Ryfr/zO3fONzPzuefhnjlHEYGZ/emb1ukCzKw9HHazRDjsZolw2M0S4bCbJcJhN0tEU2GXtFzS05KekbSmVUWZWeup0c/ZJXUBvwOWAYPAZuCzEfFk68ozs1aZ3sRrLwKeiYjnACTdBqwAjhj27u7u6OnpaWKVZnY027dvZ8+ePRpvrJmwnw7sqOkPAhcf7QU9PT1s3ry5iVWa2dFceOGFRxxr5ph9vHePdxwTSFotqV9S//DwcBOrM7NmNBP2QWBhTX8BsHPsQhGxNiJ6I6K3Uqk0sToza0YzYd8MLJb0bkkzgM8AG1tTlpm1WsPH7BHxtqQvAT8FuoDvRcQTLavMzFqqmRN0RMQ9wD0tqsXMSuQr6MwS4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJaKpy2XNUvfab5/P2/u2/aE60FX8D/BZvWfm7eN75pVe13i8ZTdLhMNulgjvxptN4PWtv8/bO275r8LYaw8PVDtHuXerjq1G7ZTL/qwwtuiGlXn72NPmNFjlxLxlN0uEw26WCIfdLBE+Zjcb441tOwr9gS9/J2+/tefVhr5nHHg7b7+yaUth7M0XqnddPvs//j5vz+ie1dC6jsRbdrNEOOxmifBuvNkYO9feV+g3uuter/1PVx+3sPvWX+Tthdf/ZUvX4y27WSIcdrNEOOxmifAxuxnw5o49eXvvL4/41PHSvXTvo3l7wZc+WRhTV3Pb5glfLel7koYkba2ZN0fSJkkD2XR2U1WYWenqeav4AbB8zLw1QF9ELAb6sr6ZTWET7sZHxP9K6hkzewVwedZeB9wP3NjCuszaav+zL+bt2qvd2m1k8KW8/dYrrxfGmr2irtGDgHkRsQsgm85tqgozK13pZ+MlrZbUL6l/eHh44heYWSkaDftuSfMBsunQkRaMiLUR0RsRvZVKpcHVmZVLM7ryr46aPi3/6jphRuGrWY2GfSOwKmuvAjY0XYmZlaqej95+BDwInC1pUNK1wM3AMkkDwLKsb2ZTWD1n4z97hKGlLa7FzErkK+jMgFkXLs7bMxZ0F8ZGBveMXby8Oi4+K293nXBcS7+3r403S4TDbpYI78abAdOOqUah++rewtjOb91b7sprnhRVWXlxaavxlt0sEQ67WSIcdrNE+JjdbIzTv3hlof/m89WrwV/+n980v4Li05x51+rq+k5dfkHz3/8IvGU3S4TDbpYI78abjaHpxf98e+/Nn8/bJ3/4fYWx2t36fQPV+79Lxe3oSUt68nb31RcVxk75i3MbrnUyvGU3S4TDbpYI78aX4NChg3n75ZcfL4zt319ztx4VT8vOPOmMvH3yyWfWLDbm9K01ZPszL+fte+98qjD25KPVM+6HDh4qjPWcNSdvX/mpswtjH/x2cZd8KvOW3SwRDrtZIhx2s0T4mL0FXn75iUJ/cEdf3h4ZeaXu77ObX+ftE048PW+f0XNVYbkTTzhtsiUm679vrz7K6Zv/Wv35Hth/cLzFx7VtS/V4/t47isf6f/23H8jbf7fmz/P2VDzP4i27WSIcdrNEeDe+Qf+3dyBvP//cTwpjEc0/PmjfG3/I288O3F4YO+vs6hVdxx03B6t65IEdhf6//8uv8vZbI4fGLj5pEcX+Hd+vfrTafdpJefuaL5zX9LpazVt2s0Q47GaJcNjNEuFj9gbtevGBvN2KY/SjGRnZW+gPDW3O24sWfazUdf+xufv2bYV+K47T63XP+urHcn/1Nx8ojHV1dX67Ws/jnxZK+rmkbZKekHR9Nn+OpE2SBrLp7PLLNbNG1fN28zbw1Yh4H3AJcJ2kc4E1QF9ELAb6sr6ZTVH1POttF7Ara78maRtwOrACuDxbbB1wP3BjKVVOEa+/Ub05weuvbe9YHXtfqV4VtnDhssLY2JsmpGDkQPUwavMvdhxlyXK98Gz1asktD+8sjF3woQXtLucdJvWXIakHOB94CJiXvREcfkOY2+rizKx16g67pJOAO4GvRMSrk3jdakn9kvqHh4cnfoGZlaKusEs6htGg3xoRd2Wzd0uan43PB4bGe21ErI2I3ojorVQqrajZzBow4TG7Rv9957vAtoj4es3QRmAVcHM23VBKhVPIyIG9Ey/UBiMj1R2rgwdHCmPTp7f2Mb9/DPbteytv769pd9K+N6ZGHbXq+Zz9UuDzwOOSHsvm/QOjIV8v6VrgBeCacko0s1ao52z8r3jHMyxyS1tbjpmVxVfQTUJX17GdLgGAadOOrWkf08FKpoaZM6s/j+7TTiyM7XnxjbbVUfup52kLZrZtvfVK70NZs0Q57GaJ8G78JMya9e68fexx3Xn7wJt72lrHyacsztvTpnUdZck0dE2vbrMu+9h7CmN3rXt87OKlueCS6lVyZ57TfZQlO8NbdrNEOOxmiXDYzRLhY/ZJqP2Pskrlgrw9uOO+ktdb/DXVrtuKPjXmphEP9G3P2y8Ovtby9R1/QvV38+kvTr2bTNbylt0sEQ67WSK8G9+gefMuydsjI8Xdw6HdDzb9/adNm5G3F57x8cJY7UeAVvSuhbMK/a99q3qPvm/88y/z9pOP7m7o+y/oObnQX31j9e+g99KFDX3PdvGW3SwRDrtZIhx2s0T4mL1BtY/kXbToysLY7Dnn5O2X9mwpjO3fV3NrrjGP9Z05c1HePvXUJXn7+ONPbarWlNVetvpvt63M2/2/Hiwst/WRXXk7DhYf6HbGWdXn6V2+/L2FsdpLdae6P55KzawpDrtZIrwbX4KZJy0at22dVXvodeGHix+Tje3/KfKW3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLxIRhl3ScpIclbZH0hKSvZfPnSNokaSCbzi6/XDNrVD1b9gPAFRFxHrAEWC7pEmAN0BcRi4G+rG9mU9SEYY9Rr2fdY7KvAFYA67L564CV47zczKaIep/P3pU9wXUI2BQRDwHzImIXQDadW16ZZtasusIeEQcjYgmwALhI0vvrXYGk1ZL6JfUPDw9P/AIzK8WkzsZHxF7gfmA5sFvSfIBsOnSE16yNiN6I6K1UKk2Wa2aNqudsfEXSKVn7eOCjwFPARmBVttgqYENZRZpZ8+r5F9f5wDpJXYy+OayPiLslPQisl3Qt8AJwTYl1mlmTJgx7RPwWOH+c+S8BS8soysxaz1fQmSXCYTdLhMNulgiH3SwRDrtZIhx2s
0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyWi7rBnj21+VNLdWX+OpE2SBrLp7PLKNLNmTWbLfj2wraa/BuiLiMVAX9Y3symqrrBLWgB8EvhOzewVwLqsvQ5Y2drSzKyV6t2y3wLcAByqmTcvInYBZNO5La7NzFqonuezXwUMRcQjjaxA0mpJ/ZL6h4eHG/ | |
kWZtYC9WzZLwWulrQduA24QtIPgd2S5gNk06HxXhwRayOiNyJ6K5VKi8o2s8maMOwRcVNELIiIHuAzwM8i4nPARmBVttgqYENpVZpZ05r5nP1mYJmkAWBZ1jezKWr6ZBaOiPuB+7P2S8DS1pdkZmXwFXRmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiajriTDZQx1fAw4Cb0dEr6Q5wO1AD7Ad+HREvFJOmWbWrMls2T8SEUsiojfrrwH6ImIx0Jf1zWyKamY3fgWwLmuvA1Y2X46ZlaXesAdwn6RHJK3O5s2LiF0A2XRuGQWaWWvU+xTXSyNip6S5wCZJT9W7guzNYTXAokWLGijRzFqhri17ROzMpkPAj4GLgN2S5gNk06EjvHZtRPRGRG+lUmlN1WY2aROGXdKJkmYebgNXAluBjcCqbLFVwIayijSz5tWzGz8P+LGkw8v/Z0TcK2kzsF7StcALwDXllWlmzZow7BHxHHDeOPNfApaWUZSZtZ6voDNLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLRF1hl3SKpDskPSVpm6QPSZojaZOkgWw6u+xizaxx9W7ZvwHcGxHnMPooqG3AGqAvIhYDfVnfzKaoep7iOgu4DPguQESMRMReYAWwLltsHbCyrCLNrHn1bNnfAwwD35f0qKTvZI9unhcRuwCy6dwS6zSzJtUT9unABcC3I+J84A0mscsuabWkfkn9w8PDDZZpZs2qJ+yDwGBEPJT172A0/LslzQfIpkPjvTgi1kZEb0T0ViqVVtRsZg2YMOwR8SKwQ9LZ2aylwJPARmBVNm8VsKGUCs2sJabXudyXgVslzQCeA77A6BvFeknXAi8A15RTopm1Ql1hj4jHgN5xhpa2thwzK4uvoDNLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEqGIaN/KpGHg90A3sKdtKz4y11HkOoqmQh2TreGMiBj3uvS2hj1fqdQfEeNdpOM6XIfrKKkG78abJcJhN0tEp8K+tkPrHct1FLmOoqlQR8tq6Mgxu5m1n3fjzRLR1rBLWi7paUnPSGrb3WglfU/SkKStNfPafitsSQsl/Ty7HfcTkq7vRC2SjpP0sKQtWR1f60QdNfV0Zfc3vLtTdUjaLulxSY9J6u9gHaXdtr1tYZfUBXwT+DhwLvBZSee2afU/AJaPmdeJW2G/DXw1It4HXAJcl/0M2l3LAeCKiDgPWAIsl3RJB+o47HpGb09+WKfq+EhELKn5qKsTdZR32/aIaMsX8CHgpzX9m4Cb2rj+HmBrTf9pYH7Wng883a5aamrYACzrZC3ACcBvgIs7UQewIPsDvgK4u1O/G2A70D1mXlvrAGYBz5OdS2t1He3cjT8d2FHTH8zmdUpHb4UtqQc4H3ioE7Vku86PMXqj0E0xekPRTvxMbgFuAA7VzOtEHQHcJ+kRSas7VEept21vZ9g1zrwkPwqQdBJwJ/CViHi1EzVExMGIWMLolvUiSe9vdw2SrgKGIuKRdq97HJdGxAWMHmZeJ+myDtTQ1G3bJ9LOsA8CC2v6C4CdbVz/WHXdCrvVJB3DaNBvjYi7OlkLQIw+3ed+Rs9ptLuOS4GrJW0HbgOukPTDDtRBROzMpkPAj4GLOlBHU7dtn0g7w74ZWCzp3dldaj/D6O2oO6Xtt8KWJEYfo7UtIr7eqVokVSSdkrWPBz4KPNXuOiLipohYEBE9jP49/CwiPtfuOiSdKGnm4TZwJbC13XVE2bdtL/vEx5gTDZ8Afgc8C/xjG9f7I2AX8Baj757XAqcyemJoIJvOaUMdH2b00OW3wGPZ1yfaXQvwQeDRrI6twD9l89v+M6mp6XKqJ+ja/fN4D7Al+3ri8N9mh/5GlgD92e/mJ8DsVtXhK+jMEuEr6MwS4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZon4fxMnTn/5iPtQAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"show_random_image_from_dataset(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQoElEQVR4nO3df5BV5X3H8ffHRUAB5ddCdkCyOmE0xiAmK+rQWCNBiTXCpDHV0Q7JkKHTmFadZAy2007TmXbsH2bsH52MJDHSajU0akGbwZA1NLW1yqKiKCL+oIgguxhBRSIsfPvHHu7uwV327t5z793s83nN7JznOefsfb6z7GfPOfcenqOIwMyGvxPqXYCZ1YbDbpYIh90sEQ67WSIcdrNEOOxmiago7JIWSNoi6RVJy4oqysyKp8F+zi6pAXgZmA/sANYD10bEi8WVZ2ZFGVHB984BXomI1wAk3Q8sBPoM++TJk6O5ubmCIc3seLZt28aePXvU27ZKwj4NeKNHfwdwwfG+obm5mfXr11cwpJkdz/nnn9/ntkqu2Xv76/GRawJJSyW1SWrr6OioYDgzq0QlYd8BnNajPx3YeexOEbE8IloioqWxsbGC4cysEpWEfT0wU9LpkkYC1wCriynLzIo26Gv2iOiU9C3gUaABuCsiXiisMjMrVCVv0BERPwd+XlAtZlZFvoPOLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiKrpd1qyaes6iJPU6H4MNgI/sZolw2M0S4dN4q6n9H3Tm+g8++Gap/eiju3Pbtm8/UGqPHNl9Gt/SMiG336JF00rtCy+YWEidw5GP7GaJcNjNEuGwmyXC1+xWdXv3HSy1b775+dy2Z57ZO+DXW7OmPdf/5S+7+zfdNDO37frrZgz49YcrH9nNEuGwmyXCp/GJevXAtlJ70/4tuW2d0f3xWOOJk0vtOeNm5/Yb3TCqrLFuv/2VUnswp+396ezxad4dd2zNbTvrrHGldstn8x/ZpcZHdrNEOOxmiXDYzRLha/ZhrOPg26X2Pe0P5ra9+MHLA369h3+zNtf/0sT5pfbF4/MP8H1r929L7ccey98GW02d+btxWbWq+/GDvmbvh6S7JLVL2tRj3URJayVtzZZp/xTNfgeUcxp/N7DgmHXLgNaImAm0Zn0zG8L6PY2PiF9Laj5m9ULgkqy9AlgHfLfAumyQ3j+8v9T+wc5/LrXfOPiRp2kP2N7Ofbn+ve0PlNojTsgfN3Y/0f0/0fbvP1Lx2IPV1vZO3cYeagb7Bt3UiNgFkC2nFFeSmVVD1d+Nl7RUUpukto6OjmoPZ2Z9GOy78bslNUXELklNQHtfO0bEcmA5QEtLS/S1nxXj8X1PldpFnLofzxG6/znXvP2fuW2fOvTVqo5drs5O/8odNdgj+2pgcdZeDKwqphwzq5ZyPnq7D3gCOFPSDklLgNuA+ZK2AvOzvpkNYeW8G39tH5vmFVyLmVWR76AbZja891xdxt11KH+X3Kc/8WYfe9bWjBkn17uEIcP3xpslwmE3S4RP44eZd465y61eJjUfLrXPPffUUnvjxtrWt2DB1JqON5T5yG6WCIfdLBEOu1kifM0+zIw+oXsSyH2H36tjHaNL7W9+8/RS+zvfyc8b/957hynS3Ln5Z70tvKqp0Nf/XeYju1kiHHazRPg0fpg5Z+xZpfbudx6v2bjjThiT688ae3apPXZO911st98+K7ffnXe+Xmo//XR+Tvno4z+snXpq/tf28su7P1678cZP5LaNHNlwnKrT4iO7WSIcdrNE+DR+mPncuO4pnf9334ZSe/+RA1Ud98JTP5vrj23o/T+gzDl/Yp/9jc/l7657/fXu+fRGjlSpfdFFk3L7TRg/cmDFJspHdrNEOOxmiXDYzRLha/ZhZtroj5XaX2v6o1L77l0rc/vtP/JBxWNdMO68UvvLk6+o+PXOnXXqcftWGR/ZzRLhsJslwqfxw9jsMZ8qtf/q4zfntv33u93zy2/54NXctkPR/SjUKSd2f8w155Tzcvt9+uTuu/UkYUObj+xmiXDYzRLhsJslwtfsiZh04vhc/6pJl/XYWONirC7KefzTaZJ+JWmzpBck3ZitnyhpraSt2XJC9cs1s8Eq5zS+E/h2RHwSuBC4QdLZwDKgNSJmAq1Z38yGqH7DHhG7IuLprP0esBmYBiwEVmS7rQAWVatIM6vcgN6gk9QMnAc8CUyNiF3Q9QcBmFJ0cWZWnLLDLmks8ABwU0S8O4DvWyqpTVJbR0fHYGo0swKUFXZJJ9IV9Hsj4sFs9W5JTdn2JqC9t++NiOUR0RIRLY2NjUXUbGaDUM678QJ+DGyOiO/32LQaWJy1FwOrii/PzIpSzufsc4E/Bp6X9Gy27i+A24CVkpYA24Grq1OimRWh37BHxONAX//LYV6x5ZhZtfh2WbNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEZ68wpIRR46U2ntXPZbb9v5/dT8X78iBD3PbRp0xrdSecE1+fvxRpzUVWWJV+chulgiH3SwRPo23Ye3Ai6+U2tv/9G+71z+7ZVCvt/uOe3L9KX9+Xan9sVuWlNpDcR59H9nNEuGwmyXCYTdLhK/ZbVg59FZ+6rNt13dPevzhq29U/Prxfv5R17v//oeldsOYk0rtKd+6jqHGR3azRDjsZonwabwNK2//y8O5fhGn7uXac+e/ldqTvvGV3LaG0aNqVkdffGQ3S4TDbpYIn8bbsLLv4XV1G/vg9l3ddTySr2PiVy6vcTUf5SO7WSIcdrNEOOxmifA1uw0rnXveqXcJAHTu2VvvEj6inGe9jZb0lKSNkl6Q9L1s/URJayVtzZYTql+umQ1WOafxHwKXRsS5wGxggaQLgWVAa0TMBFqzvpkNUeU86y2A97PuidlXAAuBS7L1K4B1wHcLr9BsABrGn5LrH3qz1yeJ16COcXUZ93jKfT57Q/YE13ZgbUQ8CUyNiF0A2XJK9co0s0qVFfaIOBwRs4HpwBxJ55Q7gKSlktoktXV0dPT/DWZWFQP66C0i9tJ1ur4A2C2pCSBb9nq+FBHLI6IlIloaGxsrLNfMBqvfa3ZJjcChiNgr6STgC8A/AKuBxcBt2XJVNQs1K8cpX/xcrv/bF17pY8/ijWjqPpiNv/L3azZuucr5nL0JWCGpga4zgZUR8YikJ4CVkpYA24Grq1inmVWonHfjnwPO62X928C8ahRlZsXzHXQ2rExa/KVc/537/qPUrvbHcJO/trDUbhg7pqpjDYbvjTdLhMNulgifxtuwMmrGtFz/4z/5u1J7+5/8Tal98PU3BzfAiIZcd/I3/rDUntrj8U9DkY/sZolw2M0S4bCbJcLX7Dasjb1gVql95v90P275Nz9dk9tv/6/bSu0jBw7mto06vft9gAnXXpHbdvKsMwupsxZ8ZDdLhMNulgifxlsyGk4+udRu/PqXc9uO7Q9HPrKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJ
cJhN0uEw26WCIfdLBEOu1kiHHazRJQd9uyxzc9IeiTrT5S0VtLWbDmhemWaWaUGcmS/Edjco78MaI2ImUBr1jezIaqssEuaDvwB8KMeqxcCK7L2CmBRsaWZWZHKPbLfAdwCHOmxbmpE7ALIllMKrs3MCtRv2CVdCbRHxIbBDCBpqaQ2SW0dHR2DeQkzK0A5R/a5wFWStgH3A5dKugfYLakJIFv2+ojMiFgeES0R0dLY2NjbLmZWA/2GPSJujYjpEdEMXAM8FhHXA6uBxdlui4FVVavSzCpWyefstwHzJW0F5md9MxuiBjSVdESsA9Zl7beBecWXZGbV4DvozBLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRJR1hNhsoc6vgccBjojokXSROCnQDOwDfhqRLxTnTLNrFIDObJ/ | |
PiJmR0RL1l8GtEbETKA165vZEFXJafxCYEXWXgEsqrwcM6uWcsMewC8kbZC0NFs3NSJ2AWTLKdUo0MyKUe5TXOdGxE5JU4C1kl4qd4Dsj8NSgBkzZgyiRDMrQllH9ojYmS3bgYeAOcBuSU0A2bK9j+9dHhEtEdHS2NhYTNVmNmD9hl3SGEnjjraBy4BNwGpgcbbbYmBVtYo0s8qVcxo/FXhI0tH9/zUi1khaD6yUtATYDlxdvTLNrFL9hj0iXgPO7WX928C8ahRlZsXzHXRmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiSgr7JLGS/qZpJckbZZ0kaSJktZK2potJ1S7WDMbvHKP7P8IrImIs+h6FNRmYBnQGhEzgdasb2ZDVDlPcT0FuBj4MUBEHIyIvcBCYEW22wpgUbWKNLPKlXNkPwPoAH4i6RlJP8oe3Tw1InYBZMspVazTzCpUTthHAJ8BfhAR5wH7GcApu6SlktoktXV0dAyyTDOrVDlh3wHsiIgns/7P6Ar/bklNANmyvbdvjojlEdESES2NjY1F1Gxmg9Bv2CPiLeANSWdmq+YBLwKrgcXZusXAqqpUaGaFGFHmfn8G3CtpJPAa8HW6/lCslLQE2A5cXZ0SzawIZYU9Ip4FWnrZNK/YcsysWnwHnVkiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCEVE7QaTOoD/AyYDe2o2cN9cR57ryBsKdQy0ho9HRK/3pdc07KVBpbaI6O0mHdfhOlxHlWrwabxZIhx2s0TUK+zL6zTusVxHnuvIGwp1FFZDXa7Zzaz2fBpvloiahl3SAklbJL0iqWaz0Uq6S1K7pE091tV8KmxJp0n6VTYd9wuSbqxHLZJGS3pK0sasju/Vo44e9TRk8xs+Uq86JG2T9LykZyW11bGOqk3bXrOwS2oA/gn4InA2cK2ks2s0/N3AgmPW1WMq7E7g2xHxSeBC4IbsZ1DrWj4ELo2Ic4HZwAJJF9ahjqNupGt68qPqVcfnI2J2j4+66lFH9aZtj4iafAEXAY/26N8K3FrD8ZuBTT36W4CmrN0EbKlVLT1qWAXMr2ctwMnA08AF9agDmJ79Al8KPFKvfxtgGzD5mHU1rQM4BXid7L20ouuo5Wn8NOCNHv0d2bp6qetU2JKagfOAJ+tRS3bq/CxdE4Wuja4JRevxM7kDuAU40mNdPeoI4BeSNkhaWqc6qjptey3Drl7WJflRgKSxwAPATRHxbj1qiIjDETGbriPrHEnn1LoGSVcC7RGxodZj92JuRHyGrsvMGyRdXIcaKpq2vT+1DPsO4LQe/enAzhqOf6yypsIumqQT6Qr6vRHxYD1rAYiup/uso+s9jVrXMRe4StI24H7gUkn31KEOImJntmwHHgLm1KGOiqZt708tw74emCnp9GyW2mvomo66Xmo+FbYk0fUYrc0R8f161SKpUdL4rH0S8AXgpVrXERG3RsT0iGim6/fhsYi4vtZ1SBojadzRNnAZsKnWdUS1p22v9hsfx7zRcAXwMvAq8Jc1HPc+YBdwiK6/nkuASXS9MbQ1W06sQR2/R9ely3PAs9nXFbWuBZgFPJPVsQn462x9zX8mPWq6hO436Gr98zgD2Jh9vXD0d7NOvyOzgbbs3+bfgQlF1eE76MwS4TvozBLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifh/R8NL8K+cohQAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"show_random_image_from_dataset(val_dataset_3_dots)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATmklEQVR4nO3dfZQV5X0H8O+XheX9bdnLZpWXRd0iVAFxQRRjlRdL1ARNSn2JFlMiOam2emKj2JzY2p605OTEY0/b05aqCalGQ3wpBK2Iq9QqCCyCKKKCiLiy7t5F3t9Zfv1jh7l3trvs3XvvzFz2+X7O2TO/Z2bunR+wv/s8M3d4hmYGEen8usSdgIhEQ8Uu4ggVu4gjVOwijlCxizhCxS7iiJyKneQMkh+S3EpyXr6SEpH8Y7bfs5MsAvARgOkAagGsBXCzmb2fv/REJF+65vDaiQC2mtk2ACD5NICZANos9tLSUquoqMjhkCJyOtu3b0djYyNb25ZLsZ8N4LO0di2AS073goqKCqxduzaHQ4rI6UyYMKHNbbmcs7f26fH/zglIziVZQ7ImmUzmcDgRyUUuxV4LYGhaewiAnS13MrMFZlZlZlWJRCKHw4lILnIp9rUAKkmOIFkM4CYAS/KTlojkW9bn7GZ2guRdAJYBKALwuJltyltmIpJXuVygg5m9CODFPOUiIiHSHXQijlCxizhCxS7iiJzO2UUKnZ086ceHVqzy4wMv/U9gv6bkl37cpW/vwLZeV0z0497XTg1sK+rZIy95RkE9u4gjVOwijlCxizhC5+zSqTTt2x9oN9z39358+I2arN7z0Ksr/XjvE88Htg3+2Y/8uHvliKzePyrq2UUcoWIXcYSG8dKpJB96JNDOdujeluNbPw20G374Ez8+66l/9uNC/EpOPbuII1TsIo7QMF7OeEc2b/XjQ6+8Gemx04f1B5e+4sf9Zl0XaR6ZUM8u4ggVu4gjVOwijtA5u5zxDlW/kWo0NcWXx+tr/Fjn7CISGxW7iCM0jJcz3sn9B+NOAQDQtP9A3Cmclnp2EUeo2EUcoWIXcYTO2eWMV5QoiTsFAEDXAsmjLe327CQfJ9lA8r20dSUkl5Pc4i0HhpumiOQqk2H8LwHMaLFuHoBqM6sEUO21RaSAtTuMN7PXSVa0WD0TwJVevBDACgD35zEvkYz1uWaKH+9e8FRw46HDkeXRe/oVkR0rG9leoCszszoA8JaD85eSiIQh9KvxJOeSrCFZk0wmwz6ciLQh26vx9STLzayOZDmAhrZ2NLMFABYAQFVVlWV5PJE2dTurzI/7zbo2sG3fwmdCPXbPy8b7ce9pl4d6rFxl27MvATDbi2cDWJyfdEQkLJl89fYUgFUARpKsJTkHwHwA00luATDda4tIAcvkavzNbWya2sZ6ESlAuoNOOpVBP/hucAVT4b5FL6QaHflKrig1AO511WWBTYmH7kkdqkth331e2NmJSN6o2EUcoWG8dCosKgq0S//ye37c/5br/fjAC9WB/U40funHRb17Bbb1mjbZj3uMHpmXPOOgnl3EESp2EUeo2EUcoXN2cUb6bbUD77glxkzioZ5dxBEqdhFHqNhFHKFiF3GEil3EEboaf4bb9nnwkUOr393tx1/uO+bHPYqDn+sXnNffjyeMDk4O3KULIZ2PenYRR6jYRRyhYhdxhM7ZzwDHT5wMtP/zhU/9eOWGLwPbMp3Rc+U7qdctW1kf2HbHN0f4cXlpjwzfUQqdenYRR6jYRRyhYfwZYNHLtYH2my2G7rn6tO5QoP1vv/3Yj+f96fl+3LN7cGIIObOoZxdxhIpdxBEqdhFH6Jy9QDXuOerHb25ojPTYtQ1H/HjVxl1+PGWCHtZ7Jsvk8U9DSb5GcjPJTSTv9taXkFxOcou3HNjee4lIfDIZxp8AcK+ZjQIwCcCdJEcDmAeg2swqAVR7bREpUJk8660OQJ0X7ye5GcDZAGYCuNLbbSGAFQDuDyVLB63bnPrfa0ePx/ek6w0f7PFjDePPbB26QEeyAsBFAFYDKPM+CE59IOg3QaSAZVzsJPsAeBbAPWa2rwOvm0uyhmRNMpnMJkcRyYOMip1kNzQX+pNm9py3up5kube9HEBDa681swVmVmVmVYlEIh85i0gW2j1nJ0kAjwHYbGYPp21aAmA2gPnecnEoGRYos9R59K5NqwLbDnyeut20qHvPwLayidP9uLhXf7Rl/6GmXFPMiwOHCyMPyV0m37NPBnAbgHdJbvDW/RWai3wRyTkAdgCYFU6KIpIPmVyNfwOBR9oHTM1vOiISFt1B1wH1617x481P/NSPv9z0Vsbv0X1A6kuLodNuDGz7/dv/2o/79SqMf5q+BZKH5E73xos4QsUu4giN0U6jbs2yQHvNT77jx02H92f1nkf3pL6h3PrMPwW2HfoiNbfcxX/xqB8vXhH8TD5yLDgnXZguGjUgsmNJuNSzizhCxS7iCBW7iCN0zt6CnUydD7//+EOBbdmep2dq5xtL/Pisy1M3JH51/OTAfsvfavXO5LwZ9pXUXX+XXlgS6rEkOurZRRyhYhdxhIbxLaR/3bZ327ux5VH72rN+POtvgnfaHUt7HNQbbwfnp2vK4lu584b2DrS/e0OFH3cv1lzxnYV6dhFHqNhFHKFiF3GEztlbOLDjw7hTAAAcqN3qx0VFwf9h/CfXDvfjKROCs/+8tTH1HLjd+477cY/uwc/1CytTE2eMrQxOotE8X4l0NurZRRyhYhdxhIbxLXQpLo47BQAAu3bLaL8hg3sF2n80rVcbe4rr1LOLOELFLuIIDeNbSIy/yo/ZrXtgmx0/2nL30JSOmdz+Tjlo3J16Uuuvn/0ksK0m7amxJ5pSU2aPanHV/pZvjfDjc4f3y3eKkmfq2UUcoWIXcYSKXcQROmdvof+wUX5cPmlGYNvO/w33CVddinv48fCrb83rey9dXhto3/vgWj9u/DKzaxHLXt0ZaP/7wtTdhj+864LAtrvvGN3RFCVk7fbsJHuQXEPyHZKbSD7krS8huZzkFm85MPx0RSRbmQzjjwKYYmZjAYwDMIPkJADzAFSbWSWAaq8tIgUqk2e9GYADXrOb92MAZgK40lu/EMAKAPfnPcMYjf2znwXah+s/8+PdH72d8/u3/GpvzPfn+3HJyItzfv+165N+fNe84COq9h84kfP7Hz6Sminj736+MbCttCT1Z/v2t87N+ViSu0yfz17kPcG1AcByM1sNoMzM6gDAWw4+3XuISLwyKnYzazKzcQCGAJhI8oL2XnMKybkka0jWJJPJ9l8gIqHo0FdvZrYHzcP1GQDqSZYDgLdsdX5jM1tgZlVmVpVIJFrbRUQi0O45O8kEgONmtodkTwDTAPwUwBIAswHM95bhfi8Vg56DygPtyf/wvB9/8uIv/Lh2xXOB/Q7Vp57Z1qW4Z2Db4LTbcStm3BbYlhjz1eyTbcXjT33sx/k4Rz8ds2D7sSe3+PEt3zzHjzUxRnwy+Z69HMBCkkVoHgksMrOlJFcBWERyDoAdAGaFm
KeI5CiTq/EbAVzUyvpdAKaGkZSI5J/uoOuA4r6pRyGNvPHeVmMAsLQxbdTD1qPHmvx42WufR3rsdO9s2uPHb61LXZi9tEpf2sRF98aLOELFLuIIDeNDEOcV5wMHU9NH702bSjpOu3ZHN+mHtE09u4gjVOwijlCxizhC5+ydTL8+qXnvSwel4sZdx+JIBwBQXqa57AuBenYRR6jYRRyhYXwn061b6vP7mmlD/PhXv9kWaR6XXFzqxxePGRTpsaV16tlFHKFiF3GEil3EETpn78TmfLvSj1+sDv4PuMbG/N7C2q1r8Bbh798+Mq/vL7lTzy7iCBW7iCM0jO/ELhiZekjPow9fFtj2gx+nHv+07dMDyEb63PA/vndMYNvXrx6a1XtKeNSzizhCxS7iCA3jHfHVS8oC7TeXfs2Pn31hR2Db2vWNfnzyZGo+vfPP6x/Y78YbKvx4YP/ | |
go6yk8KhnF3GEil3EESp2EUfonN1RxcVFfnzzDSMC21q2pXPIuGf3Htu8nuRSr11CcjnJLd5yYHvvISLx6cgw/m4Am9Pa8wBUm1klgGqvLSIFKqNiJzkEwLUAHk1bPRPAQi9eCOD6/KYmIvmUac/+CID7AJxMW1dmZnUA4C31EC+RAtZusZO8DkCDma3L5gAk55KsIVmTTCbbf4GIhCKTnn0ygG+Q3A7gaQBTSD4BoJ5kOQB4y4bWXmxmC8ysysyqEolEntIWkY5qt9jN7AEzG2JmFQBuAvCqmd0KYAmA2d5uswEsDi1LEclZLjfVzAcwneQWANO9togUqA7dVGNmKwCs8OJdAKbmPyURCYNulxVxhIpdxBEqdhFHqNhFHKFiF3GEil3EESp2EUeo2EUcoWIXcYSKXcQRKnYRR6jYRRyhYhdxhIpdxBEqdhFHqNhFHKFiF3GEHv8kkidmqcdb73t/e2Db8d37/LioV4/AtgFjz/NjFhUhLOrZRRyhYhdxhIbxIjnYtfJdP/7idyv9+OC2nRm/R/ezBvlx2bQJqfiaSYH9SGaTok89u4gjVOwijlCxizhC5+wiHdDwavD5ptv/Y2mqcfIksnF05y4/3vGrl/z42N79gf2G3XJ1Vu9/SkbF7j3UcT+AJgAnzKyKZAmA3wCoALAdwB+b2e6cshGR0HRkGH+VmY0zsyqvPQ9AtZlVAqj22iJSoHIZxs8EcKUXL0TzM+DuzzEfkYJz4tBhP659ujq4Mcuheya++N2qQHvQZRf6ce+K8g6/X6Y9uwF4meQ6knO9dWVmVgcA3nJwh48uIpHJtGefbGY7SQ4GsJzkB5kewPtwmAsAw4YNyyJFEcmHjHp2M9vpLRsAPA9gIoB6kuUA4C0b2njtAjOrMrOqRCKRn6xFpMPa7dlJ9gbQxcz2e/HVAP4WwBIAswHM95aLw0xUJC6Nr2/04xN7D0Z34BbXA5Ir1vtx79s7fs6eyTC+DMDz3n25XQH82sxeIrkWwCKScwDsADCrw0cXkci0W+xmtg3A2FbW7wIwNYykRCT/dAedSDuOfJ6MOwUAwJHPG3N6ve6NF3GEil3EESp2EUfonF2kHSwqjD4x1zwK408hIqFTsYs4QsN4kXb0GVXhx/X/vTq2PPqOGp7T69WzizhCxS7iCA3jRdpRMuF8P66r+Epg26HtX4R23K59ewXapX8wLqf3U88u4ggVu4gjVOwijtA5u0g72CXVJ1Z87+uBbVt//ls/Pta4J+djdUl7nHPLYxUP6Jvbe+f0ahE5Y6jYRRyhYbxIB/Q5Z0igff6Ds/24/uU1frx79ebAfsf3HvDjrmlDdQAYMP73/DgxrcqP+5x7dm7JtqCeXcQRKnYRR6jYRRyhc3aRHPQoK/Hj4bfN8ONht/5hYD87fsKP2S1Ydt407aFTzy7iCBW7iCM0jBcJQcuhOYu7xZRJSkY9O8kBJJ8h+QHJzSQvJVlCcjnJLd5yYNjJikj2Mh3G/yOAl8zsfDQ/CmozgHkAqs2sEkC11xaRAtVusZPsB+AKAI8BgJkdM7M9AGYCWOjtthDA9WElKSK5y6RnPwdAEsAvSK4n+aj36OYyM6sDAG85OMQ8RSRHmRR7VwDjAfyrmV0E4CA6MGQnOZdkDcmaZLIwHpAn4qJMir0WQK2ZnZpD9xk0F389yXIA8JYNrb3YzBaYWZWZVSUSiXzkLCJZaLfYzewLAJ+RHOmtmgrgfQBLAJz6Lz+zASwOJUMRyYtMv2f/cwBPkiwGsA3Ad9D8QbGI5BwAOwDMCidFEcmHjIrdzDYAqGpl09T8piMiYdHtsiKOULGLOELFLuIIFbuII1TsIo5QsYs4QsUu4giaWXQHI5MAPgVQCqAxsgO3TXkEKY+gQsijozkMN7NW70uPtNj9g5I1ZtbaTTrKQ3koj5By0DBexBEqdhFHxFXsC2I6bkvKI0h5BBVCHnnLIZZzdhGJnobxIo6ItNhJziD5IcmtJCObjZbk4yQbSL6Xti7yqbBJDiX5mjcd9yaSd8eRC8keJNeQfMfL46E48kjLp8ib33BpXHmQ3E7yXZIbSNbEmEdo07ZHVuwkiwD8C4CvARgN4GaSoyM6/C8BzGixLo6psE8AuNfMRgGYBOBO7+8g6lyOAphiZmMBjAMwg+SkGPI45W40T09+Slx5XGVm49K+6oojj/CmbTezSH4AXApgWVr7AQAPRHj8CgDvpbU/BFDuxeUAPowql7QcFgOYHmcuAHoBeBvAJXHkAWCI9ws8BcDSuP5tAGwHUNpiXaR5AOgH4BN419LynUeUw/izAXyW1q711sUl1qmwSVYAuAjA6jhy8YbOG9A8Uehya55QNI6/k0cA3AfgZNq6OPIwAC+TXEdybkx5hDpte5TF3tpzaZ38KoBkHwDPArjHzPbFkYOZNZnZODT3rBNJXhB1DiSvA9BgZuuiPnYrJpvZeDSfZt5J8ooYcshp2vb2RFnstQCGprWHANgZ4fFbymgq7Hwj2Q3Nhf6kmT0XZy4AYM1P91mB5msaUecxGcA3SG4H8DSAKSSfiCEPmNlOb9kA4HkAE2PII6dp29sTZbGvBVBJcoQ3S+1NaJ6OOi6RT4XN5kd7PgZgs5k9HFcuJBMkB3hxTwDTAHwQdR5m9oCZDTGzCjT/PrxqZrdGnQfJ3iT7nooBXA3gvajzsLCnbQ/7wkeLCw3XAPgIwMcAfhThcZ8CUAfgOJo/PecAGITmC0NbvGVJBHlcjuZTl40ANng/10SdC4AxANZ7ebwH4EFvfeR/J2k5XYnUBbqo/z7OAfCO97Pp1O9mTL8j4wDUeP82/wVgYL7y0B10Io7QHXQijlCxizhCxS7iCBW7iCNU7CKOULGLOELFLuIIFbuII/4PX0UMsbPbpmEAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"show_random_image_from_dataset(val_dataset_5_dots)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Each epoch we calculate β-VAE loss, its kl term, its conditional probability term IWAE estimate of log likelihood.\n", | |
"# It's too computation-heavy (especially IWAE estimate of log likelyhood) to do it on the whole val datasets\n", | |
"# so we take this many random samples from val datasets and calculate on them\n", | |
"VAL_SUBSAMPLING_SIZE = 2560" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_loader = DataLoader(\n", | |
" train_dataset, batch_size=256, shuffle=True, drop_last=True,\n", | |
" pin_memory=device != \"cpu\", num_workers=0\n", | |
")\n", | |
"val_loader_3_dots = DataLoader(\n", | |
" val_dataset_3_dots, batch_size=256, drop_last=True,\n", | |
" pin_memory=device != \"cpu\", num_workers=0, sampler=SubsetRandomSampler(range(VAL_SUBSAMPLING_SIZE))\n", | |
")\n", | |
"val_loader_5_dots = DataLoader(\n", | |
" val_dataset_5_dots, batch_size=256, drop_last=True,\n", | |
" pin_memory=device != \"cpu\", num_workers=0, sampler=SubsetRandomSampler(range(VAL_SUBSAMPLING_SIZE))\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"HID_DIM = 350" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = SimpleVAE(HID_DIM, 1e-1, 3).to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"optimizer = Adam(model.parameters(), lr=2e-4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# β hyperparameter of β-VAE loss\n", | |
"BETA = 1.0" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Code for training VAE\n", | |
"\n", | |
"Code below performs VAE training using β-VAE loss. It logs a lot of diagnostic information to tensorboard, including evaluation on 3 dots validation dataset and 5 dots validation dataset each epoch." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def create_trainer(model, optimizer) -> Engine:\n", | |
" def update(engine: Engine, batch: Tuple[torch.Tensor]):\n", | |
" model.train()\n", | |
" optimizer.zero_grad()\n", | |
" x = batch[0].to(device)\n", | |
" q_z_given_x, z_realization, p_x_given_z_realization = model(x, 1)\n", | |
" mean_kl = model.calc_average_kl(q_z_given_x)\n", | |
" assert torch.all(torch.isfinite(mean_kl))\n", | |
" mean_log_cond_prob = model.calc_mean_log_cond_prob(p_x_given_z_realization, x)\n", | |
" assert torch.all(torch.isfinite(mean_log_cond_prob))\n", | |
" loss = calc_beta_vae_loss(mean_kl, mean_log_cond_prob, beta=BETA)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" reconstruction = rearrange(p_x_given_z_realization.mean, \"b () ... -> b ...\")\n", | |
" engine.state.metrics[\"mean_kl\"] = float(mean_kl)\n", | |
" engine.state.metrics[\"mean_log_cond_prob\"] = float(mean_log_cond_prob)\n", | |
" engine.state.metrics[\"loss\"] = float(loss)\n", | |
" engine.state.metrics[\"mean_abs_q_z_given_x_expectation\"] = float(torch.mean(torch.abs(q_z_given_x.mean)))\n", | |
" engine.state.metrics[\"mean_q_z_given_x_std\"] = float(torch.mean(q_z_given_x.stddev))\n", | |
" return reconstruction, q_z_given_x\n", | |
" return Engine(update)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# code for debugging in case we encounter NaNs or infs\n", | |
"def find_tensors(obj):\n", | |
" if isinstance(obj, torch.Tensor):\n", | |
" return (obj,)\n", | |
" elif isinstance(obj, tuple) or isinstance(obj, list):\n", | |
" return tuple(chain.from_iterable(find_tensors(elem) for elem in obj))\n", | |
" else:\n", | |
" return ()\n", | |
"\n", | |
"def break_on_nan(module, input_, output, this_is_forward: bool) -> None:\n", | |
" tensors = chain(find_tensors(input_), find_tensors(output))\n", | |
" for t in tensors:\n", | |
" if not torch.all(torch.isfinite(t)):\n", | |
" trainer.terminate()\n", | |
" breakpoint() \n", | |
"\n", | |
"for module in model.modules():\n", | |
" if not isinstance(module, SimpleVAE):\n", | |
" module.register_forward_hook(partial(break_on_nan, this_is_forward=True))\n", | |
" module.register_backward_hook(partial(break_on_nan, this_is_forward=False))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainer = create_trainer(model, optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def add_weights_and_grads_logging(\n", | |
" trainer: Engine, tb_logger: TensorboardLogger, model: nn.Module\n", | |
") -> None:\n", | |
" def abs_mean(tensor: torch.Tensor) -> torch.Tensor:\n", | |
" with torch.no_grad():\n", | |
" return tensor.abs().mean()\n", | |
"\n", | |
" for (handler, event) in (\n", | |
" (WeightsScalarHandler(model, abs_mean), Events.ITERATION_COMPLETED),\n", | |
" (GradsScalarHandler(model, abs_mean), Events.ITERATION_COMPLETED),\n", | |
" (WeightsHistHandler(model), Events.EPOCH_COMPLETED),\n", | |
" (GradsHistHandler(model), Events.EPOCH_COMPLETED)\n", | |
" ):\n", | |
" tb_logger.attach(trainer, handler, event)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def create_evaluator(model) -> Engine:\n", | |
" def update(evaluator: Engine, batch: Tuple[torch.Tensor]):\n", | |
" model.eval()\n", | |
" x = batch[0].to(device)\n", | |
" with torch.no_grad():\n", | |
" q_z_given_x, z_realization, p_x_given_z_realization = model(x, 1)\n", | |
" mean_kl = model.calc_average_kl(q_z_given_x)\n", | |
" assert torch.all(torch.isfinite(mean_kl))\n", | |
" mean_log_cond_prob = model.calc_mean_log_cond_prob(p_x_given_z_realization, x)\n", | |
" assert torch.all(torch.isfinite(mean_log_cond_prob))\n", | |
" loss = calc_beta_vae_loss(mean_kl, mean_log_cond_prob, beta=BETA)\n", | |
" reconstruction = rearrange(p_x_given_z_realization.mean, \"b () ... -> b ...\")\n", | |
" evaluator.state.metrics[\"mean_kl_history\"].append(float(mean_kl))\n", | |
" evaluator.state.metrics[\"mean_log_cond_prob_history\"].append(float(mean_log_cond_prob))\n", | |
" evaluator.state.metrics[\"loss_history\"].append(float(loss))\n", | |
" evaluator.state.metrics[\"mean_abs_q_z_given_x_expectation_history\"].append(float(torch.mean(torch.abs(q_z_given_x.mean))))\n", | |
" evaluator.state.metrics[\"mean_q_z_given_x_std\"].append(float(torch.mean(q_z_given_x.stddev)))\n", | |
" \n", | |
" # calculate and log mean iwll\n", | |
" evaluator.state.metrics[\"mean_iwll_history\"].append(float(model.calc_mean_iwll(x, num_samples_per_image=10)))\n", | |
" return reconstruction, q_z_given_x\n", | |
" evaluator = Engine(update)\n", | |
" \n", | |
" @evaluator.on(Events.EPOCH_STARTED)\n", | |
" def init_dicts(evaluator: Engine) -> None:\n", | |
" for field in (\n", | |
" \"mean_kl_history\", \"mean_log_cond_prob_history\", \"loss_history\",\n", | |
" \"mean_abs_q_z_given_x_expectation_history\", \"mean_q_z_given_x_std\", \"mean_iwll_history\"\n", | |
" ):\n", | |
" evaluator.state.metrics[field] = []\n", | |
" \n", | |
" return evaluator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"evaluator_3_dots = create_evaluator(model)\n", | |
"evaluator_5_dots = create_evaluator(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<ignite.engine.engine.RemovableEventHandle at 0x7f7c92d92438>" | |
] | |
}, | |
"execution_count": 84, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"outputs_dir = os.path.join(OUTPUTS_ROOT_DIR, get_now_as_str(utc=False, seconds=True))\n", | |
"checkpointer = ModelCheckpoint(outputs_dir, \"vae\", save_interval=1, n_saved=50)\n", | |
"trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {\"model\": model, \"optimizer\": optimizer})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def log_images(tb_logger, engine_for_determining_epoch, tag, engine):\n", | |
" num_images = 32\n", | |
" input_images_grid = make_grid(\n", | |
" engine.state.batch[0][:num_images], nrow=min(8, num_images), normalize=True,\n", | |
" range=pixel_values_range, pad_value=0.0\n", | |
" )\n", | |
" reconstructions_grid = make_grid(\n", | |
" engine.state.output[0][:num_images], nrow=min(8, num_images), normalize=True,\n", | |
" range=pixel_values_range, pad_value=0.0\n", | |
" )\n", | |
" epoch = engine_for_determining_epoch.state.epoch\n", | |
" tb_logger.writer.add_image(f\"epoch/{tag}/input\", input_images_grid, global_step=epoch)\n", | |
" tb_logger.writer.add_image(f\"epoch/{tag}/reconstructions\", reconstructions_grid, global_step=epoch)\n", | |
" tb_logger.writer.add_histogram(\n", | |
" f\"epoch/{tag}/q_z_given_x_expectation\", engine.state.output[1].mean.detach().cpu().numpy(), global_step=epoch\n", | |
" )\n", | |
" tb_logger.writer.add_histogram(\n", | |
" f\"epoch/f{tag}/q_z_given_x_std\", engine.state.output[1].stddev.detach().cpu().numpy(), global_step=epoch\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def log_evaluator_metrics(tb_logger: TensorboardLogger, engine_for_determining_epoch, tag, evaluator):\n", | |
" for metric_name, records in evaluator.state.metrics.items():\n", | |
" mean = sum(records) / len(records)\n", | |
" s = f\"epoch/{tag}/{metric_name.replace('_history', '')}\"\n", | |
" tb_logger.writer.add_scalar(s, mean, global_step=engine_for_determining_epoch.state.epoch)\n", | |
" if \"iwll\" in s:\n", | |
" print(f\"logged {s} = {mean}, type={type(mean)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"logged epoch/val_3_dots/mean_iwll = 9991.90634765625, type=<class 'float'>\n", | |
"logged epoch/val_5_dots/mean_iwll = -24523.9267578125, type=<class 'float'>\n" | |
] | |
} | |
], | |
"source": [ | |
"with TensorboardLogger(outputs_dir) as tb_logger:\n", | |
" tb_logger.attach(\n", | |
" trainer,\n", | |
" OutputHandler(\n", | |
" \"iteration/train\",\n", | |
" metric_names=[\"mean_kl\", \"mean_log_cond_prob\", \"loss\", \"mean_abs_q_z_given_x_expectation\", \"mean_q_z_given_x_std\"]\n", | |
" ),\n", | |
" Events.ITERATION_COMPLETED\n", | |
" )\n", | |
" add_weights_and_grads_logging(trainer, tb_logger, model)\n", | |
" trainer.add_event_handler(Events.EPOCH_COMPLETED, partial(log_images, tb_logger, trainer, \"train\"))\n", | |
" evaluator_3_dots.add_event_handler(Events.COMPLETED, partial(log_images, tb_logger, trainer, \"val_3_dots\"))\n", | |
" evaluator_5_dots.add_event_handler(Events.COMPLETED, partial(log_images, tb_logger, trainer, \"val_5_dots\"))\n", | |
" evaluator_3_dots.add_event_handler(Events.COMPLETED, partial(log_evaluator_metrics, tb_logger, trainer, \"val_3_dots\"))\n", | |
" evaluator_5_dots.add_event_handler(Events.COMPLETED, partial(log_evaluator_metrics, tb_logger, trainer, \"val_5_dots\"))\n", | |
" @trainer.on(Events.EPOCH_COMPLETED)\n", | |
" def run_evaluators(trainer):\n", | |
" evaluator_3_dots.run(val_loader_3_dots)\n", | |
" evaluator_5_dots.run(val_loader_5_dots)\n", | |
" @trainer.on(Events.EPOCH_COMPLETED)\n", | |
" def log_unconditional_samples(trainer):\n", | |
" model.eval()\n", | |
" num_images = 32\n", | |
" with torch.no_grad():\n", | |
" grid = make_grid(\n", | |
" model.generate_unconditional_samples(num_images), nrow=min(8, num_images),\n", | |
" normalize=True, range=pixel_values_range, pad_value=0.0\n", | |
" )\n", | |
" tb_logger.writer.add_image(f\"epoch/unconditional_samples\", grid, global_step=trainer.state.epoch)\n", | |
" \n", | |
" trainer.run(train_loader, max_epochs=1200)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Training plots\n", | |
"The task asks us to plot ELBO vs iteration and KL vs iteration.\n", | |
"\n", | |
"Please find \"iteration/train/loss\" in tensorboard - this is negative ELBO vs iteration plot. Also, you can see it on the screenshot below.\n", | |
"\n", | |
"<img src=\"iteration_train_loss.png\"/>\n", | |
"\n", | |
"Please find \"iteration/train/mean_kl\" in tensorboard - this is KL vs iteration plot. Also, you can see it on the screenshot below.\n", | |
"\n", | |
"<img src=\"iteration_train_mean_kl.png\"/>\n", | |
"\n", | |
"You can also check out \"iteration/train/mean_log_cond_prob\" - this is the estimate the non-KL term in ELBO." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### NLL on validation datasets\n", | |
"\n", | |
"We are asked to \"calculate NLL on a test set, contating only 3 dots\" and \"calculate NLL on a test set, containing only 5 dots\".\n", | |
"\n", | |
"Please see the rightmost point on \"epoch/val_3_dots/mean_iwll\" and \"epoch/val_5_dots/mean_iwll\". These plots depict importance weighted (`num_samples_per_image=10`) mean log likelihood, estimated each epoch on a random subset of the two validation datasets. Also, screenshots are attached below.\n", | |
"\n", | |
"<img src=\"epoch_val_3_dots_iwll.png\"/>\n", | |
"\n", | |
"<img src=\"epoch_val_5_dots_iwll.png\"/>\n", | |
"\n", | |
"Values after the last epoch are: -1.8e3 for 3 dots, -5.8e4 for 5 dots." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Sampling\n", | |
"We are also asked to:\n", | |
">sample 25 images from the VAE and draw then on the 5 × 5 grid. Comment on the generalization ability of the model \n", | |
"\n", | |
"In tensorboard see images in \"epoch/unconditional_samples\" and \"epoch/$DATASET_NAME/reconstructions\". Also, see screenshots below.\n", | |
"\n", | |
"These are unconditionally generated samples:\n", | |
"\n", | |
"<img src=\"epoch_unconditional_samples.png\" />\n", | |
"\n", | |
"These are training dataset samples and their reconstructions:\n", | |
"\n", | |
"<img src=\"epoch_train_input_and_reconstructions.png\" />\n", | |
"\n", | |
"These are validation 3 dots samples and their reconstructions:\n", | |
"\n", | |
"<img src=\"epoch_val_3_dots_input_and_reconstructions.png\" />\n", | |
"\n", | |
"These are validation 5 dots samples and their reconstructions:\n", | |
"\n", | |
"<img src=\"epoch_val_5_dots_input_and_reconstructions.png\" />\n", | |
"\n", | |
"Here's my comment on generalization ability:\n", | |
"\n", | |
"Judging by the images, the model didn't overfit - it performs similarly on train dataset and on 3 dots validation dataset.\n", | |
"When it's asked to reconstruct an image from the 5 dots validation dataset, it usually, but not always, successfully preserves all 5 dots. Sometimes one of the brighter dots becomes too bright, or a dot which is near another dot looks broken. It's a pleasant surprise to me that it works this well on 5 dots, considering the fact that the model didn't train on such images. I think the model learned to memorize, for each region of the screen, whether there is a dot there, hence it works fine even for 5 dots.\n", | |
"\n", | |
"Also, unconditionally generated images aren't very good - they are often blurry and tend to contain more than 3 dots." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Bonus task (20 pts max)\n", | |
"\n", | |
"Assume that we want to quantify the generalization ability of the model. To do that, we need to accurately compute number of dots on all the generated images.\n", | |
"\n", | |
"1. Train neural network, which will classify images based on number of dots with high probability (>95%)\n", | |
"2. Generate 1000 images from you VAE\n", | |
"3. Classify generated images and plot the proportion of images with ${1, 2, ..., 12}$ dots in generated sample" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### I didn't do this part" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |