Skip to content

Instantly share code, notes, and snippets.

@jinglescode
Created January 23, 2021 13:25

Revisions

  1. jinglescode created this gist Jan 23, 2021.
    32 changes: 32 additions & 0 deletions gan-get_discriminator_loss.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,32 @@
    def get_discriminator_loss(generator, discriminator, criterion, real_samples, n_samples, dim_noise, device):
    '''
    Discriminator predict and get loss
    Parameters:
    generator:
    generator network
    discriminator:
    discriminator network
    criterion:
    loss function, likely `nn.BCEWithLogitsLoss()`
    real_samples:
    samples from training dataset
    n_samples: int
    number of samples to generate
    dim_noise: int
    dimension of noise vector
    device: string
    device, cpu or cuda
    Returns:
    discriminator_loss:
    loss scalar
    '''

    random_noise = get_noise(n_samples, dim_noise, device=device)
    generated_samples = generator(random_noise)
    discriminator_fake_pred = discriminator(generated_samples.detach())
    discriminator_fake_loss = criterion(discriminator_fake_pred, torch.zeros_like(discriminator_fake_pred))
    discriminator_real_pred = discriminator(real_samples)
    discriminator_real_loss = criterion(discriminator_real_pred, torch.ones_like(discriminator_real_pred))
    discriminator_loss = (discriminator_fake_loss + discriminator_real_loss) / 2

    return discriminator_loss