#%% IMPORTS
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer

from torch.nn import functional as F
import pyro
import pyro.distributions as dist
# %%
class CoolSystem(pl.LightningModule):
    """Minimal LightningModule: fit y ~ a + b*x with a 1-in/1-out linear layer
    and a mean-absolute-error loss, on synthetic data generated on the fly."""

    def __init__(self):
        super().__init__()
        # not the best model... a single scalar weight plus bias.
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        l1_loss = (prediction - y).abs().mean()
        return {'loss': l1_loss, 'log': {'train_loss': l1_loss}}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        prediction = self(x)
        return {'val_loss': (prediction - y).abs().mean()}

    def validation_end(self, outputs):
        # OPTIONAL: reduce the per-batch validation losses to a single scalar.
        mean_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': mean_loss, 'log': {'loss': mean_loss}}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # Synthetic data: y = 2 + x plus N(0, 0.2^2) noise, x in [0, 100).
        x = torch.arange(100).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        dataset = torch.utils.data.TensorDataset(x, 2 + x + noise)
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=2)

    @pl.data_loader
    def val_dataloader(self):
        # Same generative process as training, but only 10 points.
        x = torch.arange(10).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        dataset = torch.utils.data.TensorDataset(x, 2 + x + noise)
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=2)

# %%
# Fit the deterministic baseline model end-to-end.
system = CoolSystem()
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1)   
trainer.fit(system)

# RESULTS
# Inspect the fitted weight and bias of the single linear layer.
list(system.parameters())

# %% PYRO LIGHTNING!!
#%%
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer

from torch.nn import functional as F
import pyro
import pyro.distributions as dist
class PyroOptWrap(pyro.infer.SVI):
    """Thin SVI wrapper exposing an empty ``state_dict``.

    NOTE(review): presumably this exists so Lightning's checkpoint hooks can
    treat the SVI object like an optimizer without failing — confirm against
    the Lightning version in use. Pyro's learnable parameters live in the
    global param store, not on this object.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def state_dict(self):
        # Nothing to serialize here.
        return {}

class PyroCoolSystem(pl.LightningModule):
    """LightningModule driving Pyro SVI for a Bayesian linear regression.

    Lightning's optimization machinery is deliberately bypassed:
    ``backward`` and ``optimizer_step`` are no-ops, and each
    ``training_step`` delegates the entire forward/backward/update cycle
    to ``self.svi.step`` (created in ``configure_optimizers``).
    """

    def __init__(self, num_data = 100, lr = 1e-3):
        super(PyroCoolSystem, self).__init__()
        self.lr = lr  # learning rate handed to pyro.optim.SGD
        self.num_data =num_data  # number of synthetic training points

    def model(self, batch):
        # Generative model: y ~ Normal(yhat, 0.2), with yhat built from the
        # priors sampled inside forward().
        x, y = batch
        yhat = self.forward(x)
        obsdistr = dist.Normal(yhat, 0.2)#.to_event(1)
        pyro.sample("obs", obsdistr, obs = y)
        return yhat

    def guide(self, batch):
        # Variational family: fixed-width (0.1) Normals over slope and
        # intercept whose means are learnable entries in the Pyro param store.
        b_m = pyro.param("b-mean", torch.tensor(0.1))
        a_m = pyro.param("a-mean", torch.tensor(0.1))
        b = pyro.sample("beta", dist.Normal(b_m , 0.1))
        a = pyro.sample("alpha", dist.Normal(a_m,0.1))

    def forward(self, x):
        # Priors: standard-Normal slope ("beta") and intercept ("alpha").
        # Under SVI these sample sites are matched against the guide's.
        b = pyro.sample("beta", dist.Normal(0,1))
        a = pyro.sample("alpha", dist.Normal(0,1))
        yhat = a + x*b
        return yhat

    def training_step(self, batch, batch_idx):
        # svi.step performs the full ELBO forward/backward/update itself.
        loss = self.svi.step(batch)
        # Re-wrap the returned Python float as a tensor with requires_grad so
        # Lightning accepts it as a loss; no real autograd happens here
        # (backward/optimizer_step below are no-ops).
        loss = torch.tensor(loss).requires_grad_(True)
        tensorboard_logs = {'running/loss': loss, 'param/a-mean': pyro.param("a-mean"), 'param/b-mean': pyro.param("b-mean") }
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # evaluate_loss computes the ELBO without taking a gradient step.
        loss = self.svi.evaluate_loss(batch)
        loss = torch.tensor(loss).requires_grad_(True)
        return {'val_loss': loss}

    def validation_end(self, outputs):
        # OPTIONAL: average the per-batch validation ELBO losses.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # The SVI object doubles as the "optimizer" returned to Lightning;
        # PyroOptWrap gives it an empty state_dict so checkpointing works.
        # NOTE(review): self.svi is created here and consumed by
        # training_step/validation_step — Lightning calls this before fitting.
        self.svi = PyroOptWrap(model=self.model,
                guide=self.guide,
                optim=pyro.optim.SGD({"lr": self.lr, "momentum":0.0}),
                loss=pyro.infer.Trace_ELBO())

        return [self.svi]
        
    @pl.data_loader
    def train_dataloader(self):
        # Synthetic data: y = 2 + x plus N(0, 0.2^2) noise, x ~ U[0, 1).
        x = torch.rand((self.num_data,)).float().view(-1,1)
        y = 2 + x + torch.distributions.Normal(0,1).sample((len(x),)).view(-1,1)*0.2

        ds = torch.utils.data.TensorDataset(x,y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size = 2)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        # Same generative process; 100 validation points in batches of 10.
        x = torch.rand((100,)).float().view(-1,1)
        y = 2 + x + torch.distributions.Normal(0,1).sample((len(x),)).view(-1,1)*0.2

        ds = torch.utils.data.TensorDataset(x,y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size = 10)
        return dataloader

    def optimizer_step(self, *args, **kwargs):
        # No-op: the parameter update already happened inside svi.step().
        pass
    def backward(self, *args, **kwargs):
        # No-op: Pyro computes and applies its own gradients.
        pass
# %%
# Reset Pyro's global parameter store so repeated cell runs start fresh
# (pyro.param sites in the guide would otherwise keep their learned values).
pyro.clear_param_store()
system = PyroCoolSystem(num_data=2)
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1, max_epochs=100)
trainer.fit(system)


# %%


# %%