Created January 23, 2021 13:25
import torch


def get_discriminator_loss(generator, discriminator, criterion, real_samples, n_samples, dim_noise, device):
    '''
    Compute the discriminator loss on one batch of generated and real samples.

    Parameters:
        generator:
            generator network
        discriminator:
            discriminator network
        criterion:
            loss function, likely `nn.BCEWithLogitsLoss()`
        real_samples:
            samples from training dataset
        n_samples: int
            number of samples to generate
        dim_noise: int
            dimension of noise vector
        device: string
            device, 'cpu' or 'cuda'
    Returns:
        discriminator_loss:
            scalar loss, the mean of the fake and real losses
    '''
    # get_noise is not defined in this file; it is assumed to be defined elsewhere.
    random_noise = get_noise(n_samples, dim_noise, device=device)
    generated_samples = generator(random_noise)

    # Detach so this loss does not backpropagate into the generator.
    discriminator_fake_pred = discriminator(generated_samples.detach())
    # Generated samples are labelled 0 (fake).
    discriminator_fake_loss = criterion(discriminator_fake_pred, torch.zeros_like(discriminator_fake_pred))

    # Real samples are labelled 1 (real).
    discriminator_real_pred = discriminator(real_samples)
    discriminator_real_loss = criterion(discriminator_real_pred, torch.ones_like(discriminator_real_pred))

    discriminator_loss = (discriminator_fake_loss + discriminator_real_loss) / 2
    return discriminator_loss
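
The function above calls `get_noise`, which this file does not define. The sketch below assumes it draws standard-normal noise with `torch.randn`, which may differ from the author's version. It repeats the same steps on toy `nn.Linear` networks, which are placeholders rather than the author's models, to show the effect of `.detach()`: after `backward()`, only the discriminator receives gradients.

```python
import torch
from torch import nn


def get_noise(n_samples, dim_noise, device="cpu"):
    # Hypothetical helper (an assumption, not from the original file):
    # one standard-normal noise vector per sample.
    return torch.randn(n_samples, dim_noise, device=device)


torch.manual_seed(0)
n_samples, dim_noise, dim_data = 16, 8, 4

# Toy stand-ins for the real networks (placeholders for illustration only).
generator = nn.Linear(dim_noise, dim_data)
discriminator = nn.Linear(dim_data, 1)
criterion = nn.BCEWithLogitsLoss()
real_samples = torch.randn(n_samples, dim_data)

# Same steps as get_discriminator_loss.
generated_samples = generator(get_noise(n_samples, dim_noise))
fake_pred = discriminator(generated_samples.detach())
fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
real_pred = discriminator(real_samples)
real_loss = criterion(real_pred, torch.ones_like(real_pred))
loss = (fake_loss + real_loss) / 2

loss.backward()

# The detach cuts the graph, so generator parameters have no gradient.
generator_has_grad = any(p.grad is not None for p in generator.parameters())
discriminator_has_grad = all(p.grad is not None for p in discriminator.parameters())
print(generator_has_grad, discriminator_has_grad)
```

Without the detach, the same `backward()` call would also fill in generator gradients, which is wasted work in a discriminator update step.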