Skip to content

Instantly share code, notes, and snippets.

@ag1805x
Last active July 11, 2025 06:02
Show Gist options
  • Save ag1805x/13c9720860baad32567ab1e38023de01 to your computer and use it in GitHub Desktop.
Save ag1805x/13c9720860baad32567ab1e38023de01 to your computer and use it in GitHub Desktop.
Code to show how the focal loss changes as alpha and gamma are varied.
# Show how the focal loss changes as alpha and gamma are varied.
#
# Focal loss scales the per-element binary cross-entropy (BCE) by a
# modulating factor:  FL = alpha * (1 - p_t) ** gamma * BCE,
# where p_t is the predicted probability of the true class.
#
# NOTE(review): alpha here is a single scalar applied to every element, so
# the alpha sweep below only rescales the loss linearly (alpha = 0 zeroes it).
# The alpha-balanced variant (Lin et al.; torchvision.ops.sigmoid_focal_loss)
# weights positives by alpha and negatives by (1 - alpha) instead -- that
# variant would not make sense for alpha = 2. Confirm which one is intended.

# Load libraries
import torch
import torch.nn.functional as F


def compute_focal_loss(bce_loss, pt, alpha, gamma):
    """Return the element-wise focal loss ``alpha * (1 - pt) ** gamma * bce_loss``.

    Args:
        bce_loss: per-element BCE loss (computed with ``reduction="none"``).
        pt: probability of the true class for each element, ``exp(-bce_loss)``.
        alpha: scalar weight multiplied into every element.
        gamma: focusing exponent; larger values shrink the loss of confident
            (high ``pt``) predictions more than that of wrong ones.

    Returns:
        Tensor of the same shape as ``bce_loss``.
    """
    return alpha * (1 - pt) ** gamma * bce_loss


# Define data to calculate focal loss:
# indices 0 and 1 are correct predictions, indices 2 and 3 are incorrect ones.
logits = torch.tensor([-1.5, 1.5, 1.5, -1.5], requires_grad=False)
targets = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
# torch.sigmoid is the canonical op; F.sigmoid only forwards to it.
probs = torch.sigmoid(logits)
print("LOGITS:: ", logits)
print("TARGETS:: ", targets)
print("PROBABILITY:: ", probs)

# Calculate the BCE loss and the estimated probability of the correct class
# (prediction confidence). For 0/1 targets BCE = -log(p_t), so exp(-BCE)
# recovers p_t without branching on the target value.
BCE_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
pt = torch.exp(-BCE_loss)
print("BCE_LOSS:: ", BCE_loss)
print("PREDICTION_CONFIDENCE:: ", pt)

# Track change of gamma (alpha held fixed)
print("Impact of changing gamma")
alpha = 0.25
for gamma in [0.5, 1, 2, 3, 4, 5]:
    focal_loss = compute_focal_loss(BCE_loss, pt, alpha, gamma)
    print(focal_loss)

# Track change of alpha (gamma held fixed)
print("Impact of changing alpha")
gamma = 2
for alpha in [0, 0.25, 0.5, 0.75, 1, 2]:
    focal_loss = compute_focal_loss(BCE_loss, pt, alpha, gamma)
    print(focal_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment