# %% IMPORTS
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from torch.nn import functional as F
import pyro
import pyro.distributions as dist


# %%
class CoolSystem(pl.LightningModule):
    """Minimal LightningModule: fit y = a + b*x with a single Linear layer and L1 loss.

    Synthetic data is y = 2 + x + N(0, 0.2) noise (see ``_make_dataloader``).
    Uses the pre-1.0 Lightning API (``@pl.data_loader``, ``validation_end``) —
    consistent with the rest of this file.
    """

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.forward(x)
        loss = (yhat - y).abs().mean()  # mean absolute error (L1)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.forward(x)
        loss = (yhat - y).abs().mean()
        return {'val_loss': loss}

    def validation_end(self, outputs):
        # OPTIONAL: aggregate per-batch validation losses.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Log under 'val_loss' (was 'loss', which collided with the training
        # loss key and disagreed with PyroCoolSystem.validation_end below).
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @staticmethod
    def _make_dataloader(num_points, batch_size=2):
        """Build a DataLoader over y = 2 + x + N(0, 0.2) noise, x = 0..num_points-1."""
        x = torch.arange(num_points).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    @pl.data_loader
    def train_dataloader(self):
        return self._make_dataloader(100)

    @pl.data_loader
    def val_dataloader(self):
        return self._make_dataloader(10)


# %%
system = CoolSystem()

# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1)
trainer.fit(system)

# RESULTS
list(system.parameters())
# %% PYRO LIGHTNING!!
# %%
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from torch.nn import functional as F
import pyro
import pyro.distributions as dist


class PyroOptWrap(pyro.infer.SVI):
    """SVI wrapper presenting enough of the torch optimizer interface to Lightning.

    Lightning calls ``state_dict``/``load_state_dict`` on optimizers when
    checkpointing; Pyro keeps its parameters in the global param store instead,
    so both are empty here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def state_dict(self):
        # Nothing to save: params live in pyro's global param store.
        return {}

    def load_state_dict(self, state_dict):
        # Counterpart to state_dict(); without it, resuming from a Lightning
        # checkpoint would raise AttributeError.
        pass


class PyroCoolSystem(pl.LightningModule):
    """Bayesian linear regression (y ~ Normal(a + b*x, 0.2)) trained with Pyro SVI.

    Optimization is delegated entirely to Pyro: ``svi.step`` in
    ``training_step`` performs the parameter update, and Lightning's
    ``backward``/``optimizer_step`` are overridden as no-ops so nothing is
    updated twice.
    """

    def __init__(self, num_data=100, lr=1e-3):
        super(PyroCoolSystem, self).__init__()
        self.lr = lr
        self.num_data = num_data

    def model(self, batch):
        """Generative model: observe y under Normal(forward(x), 0.2)."""
        x, y = batch
        yhat = self.forward(x)
        obsdistr = dist.Normal(yhat, 0.2)  # .to_event(1)
        pyro.sample("obs", obsdistr, obs=y)
        return yhat

    def guide(self, batch):
        """Variational guide: Normal(param, 0.1) posteriors over slope/intercept."""
        b_m = pyro.param("b-mean", torch.tensor(0.1))
        a_m = pyro.param("a-mean", torch.tensor(0.1))
        b = pyro.sample("beta", dist.Normal(b_m, 0.1))
        a = pyro.sample("alpha", dist.Normal(a_m, 0.1))

    def forward(self, x):
        # Priors over slope/intercept; SVI replays these sites against the guide.
        b = pyro.sample("beta", dist.Normal(0, 1))
        a = pyro.sample("alpha", dist.Normal(0, 1))
        yhat = a + x * b
        return yhat

    def training_step(self, batch, batch_idx):
        # svi.step evaluates the ELBO AND updates params in pyro's store.
        loss = self.svi.step(batch)
        # Wrap as a grad-enabled tensor purely to satisfy Lightning's loss
        # plumbing; backward()/optimizer_step() below are no-ops.
        loss = torch.tensor(loss).requires_grad_(True)
        tensorboard_logs = {'running/loss': loss,
                            'param/a-mean': pyro.param("a-mean"),
                            'param/b-mean': pyro.param("b-mean")}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self.svi.evaluate_loss(batch)
        # No gradient is needed during validation (was requires_grad_(True)).
        return {'val_loss': torch.tensor(loss)}

    def validation_end(self, outputs):
        # OPTIONAL: aggregate per-batch validation ELBO losses.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        # print(pyro.param("a-mean"), pyro.param('b-mean'))
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        self.svi = PyroOptWrap(model=self.model,
                               guide=self.guide,
                               optim=pyro.optim.SGD({"lr": self.lr, "momentum": 0.0}),
                               loss=pyro.infer.Trace_ELBO())
        return [self.svi]

    @staticmethod
    def _make_dataloader(num_points, batch_size):
        """Build a DataLoader over y = 2 + x + N(0, 0.2) noise, x ~ U[0, 1)."""
        x = torch.rand((num_points,)).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    @pl.data_loader
    def train_dataloader(self):
        return self._make_dataloader(self.num_data, batch_size=2)

    @pl.data_loader
    def val_dataloader(self):
        return self._make_dataloader(100, batch_size=10)

    def optimizer_step(self, *args, **kwargs):
        # Deliberate no-op: svi.step() already applied the update.
        pass

    def backward(self, *args, **kwargs):
        # Deliberate no-op: Pyro computes its own gradients inside svi.step().
        pass


# %%
pyro.clear_param_store()
system = PyroCoolSystem(num_data=2)

# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1, max_epochs=100)
trainer.fit(system)
# %%
# %%