###############################################################################
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


###############################################################################
class FlashModel(pl.LightningModule):
    """LightningModule that wraps ``model`` and logs a validation loss.

    This is the minimal version: ``validation_step`` only logs the loss and
    returns nothing.
    """

    def __init__(self, model):
        """Store the wrapped model that ``validation_step`` will call."""
        super().__init__()
        self.model = model

    def validation_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one validation batch.

        Args:
            batch: A pair that unpacks into ``(x, y)``; ``x`` is passed to
                the wrapped model and ``y`` is used as the loss target.
            batch_idx: Index of the batch (unused here).
        """
        x, y = batch
        y_hat = self.model(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)


###############################################################################
## Under the Hood
##
## Illustrative pseudo-code only: the names below (train_dataloader,
## validate_at_some_point, ...) are not defined in this file, so the loop is
## kept as a comment rather than as executable module-level code.
#
#   for batch in train_dataloader:
#       loss = model.training_step()
#       loss.backward()
#       ##.....
#
#       if validate_at_some_point:
#           ## disable grads + batchnorm + dropout
#           torch.set_grad_enabled(False)
#           model.eval()
#
#           ##------------------- VAL loop -------------------##
#           for val_batch in model.val_dataloader:
#               val_out = model.validation_step(val_batch)
#           ##------------------- VAL loop -------------------##
#
#           ## enable grads + batchnorm + dropout
#           torch.set_grad_enabled(True)
#           model.train()


###############################################################################
## If we need to do something with the validation outputs, implement the
## validation_epoch_end() hook.
##
## NOTE: this intentionally re-defines FlashModel; it is the same class as
## above, extended with a return value and the epoch-end hook.
class FlashModel(pl.LightningModule):
    """LightningModule that wraps ``model`` and returns per-batch outputs.

    Same as the minimal version, except that ``validation_step`` returns a
    value and ``validation_epoch_end`` receives those returned values.
    """

    def __init__(self, model):
        """Store the wrapped model that ``validation_step`` will call."""
        super().__init__()
        self.model = model

    def validation_step(self, batch, batch_idx):
        """Log the cross-entropy loss for one batch and return its output.

        Args:
            batch: A pair that unpacks into ``(x, y)``; ``x`` is passed to
                the wrapped model and ``y`` is used as the loss target.
            batch_idx: Index of the batch (unused here).

        Returns:
            The per-batch value to be handed to ``validation_epoch_end``.
        """
        x, y = batch
        y_hat = self.model(x)
        val_loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", val_loss)
        # Placeholder: replace ``...`` with whatever should be collected
        # for this batch (e.g. something derived from ``y_hat``).
        pred = ...
        return pred  ## <- this is the new line here

    ## v-- this is the new hook that needs to be implemented.
    ## NOTE(review): newer Lightning releases appear to have replaced this
    ## hook with on_validation_epoch_end() - confirm against the installed
    ## version.
    def validation_epoch_end(self, validation_step_outputs):
        """Process the values returned by ``validation_step``.

        Args:
            validation_step_outputs: Iterable of the values returned by
                ``validation_step``.
        """
        for pred in validation_step_outputs:
            ## Do something with "pred"
            pass


###############################################################################