Last active July 11, 2025 06:02
-
-
Save ag1805x/13c9720860baad32567ab1e38023de01 to your computer and use it in GitHub Desktop.
Code to see how focal loss changes as alpha and gamma are changed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code to see how focal loss changes as alpha and gamma are changed.
#
# Alpha-balanced focal loss (Lin et al., 2017, "Focal Loss for Dense Object
# Detection"):
#     FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t)
# where p_t is the predicted probability of the true class and
# alpha_t = alpha for positive targets, (1 - alpha) for negative targets.
# Load libraries
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Return the element-wise (unreduced) alpha-balanced sigmoid focal loss.

    Args:
        logits: Raw (pre-sigmoid) scores.
        targets: Binary labels (0. or 1.), same shape as ``logits``.
        alpha: Weight for positive targets, in [0, 1]; negative targets are
            weighted by ``1 - alpha``. Pass ``None`` to disable class
            balancing (every element then gets weight 1).
        gamma: Non-negative focusing parameter. With ``gamma == 0`` and
            ``alpha=None`` the result equals plain binary cross-entropy.

    Returns:
        Tensor of per-element losses, same shape as ``logits``.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or ``gamma`` is negative.
    """
    if alpha is not None and not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1] or None, got {alpha}")
    if gamma < 0:
        raise ValueError(f"gamma must be non-negative, got {gamma}")

    # BCE from logits is numerically stable; since BCE = -log(p_t),
    # exp(-BCE) recovers p_t without branching on the target value.
    bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    pt = torch.exp(-bce_loss)
    # (1 - p_t) ** gamma shrinks the loss of confident, correct predictions.
    loss = (1 - pt) ** gamma * bce_loss
    if alpha is not None:
        # Class-dependent weight: alpha for positives, 1 - alpha for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss


def main():
    """Print BCE and focal-loss values for a small fixed example."""
    # Define data to calculate focal loss:
    # indices 0 and 1 are correct predictions, indices 2 and 3 are incorrect.
    logits = torch.tensor([-1.5, 1.5, 1.5, -1.5])
    targets = torch.tensor([0, 1, 0, 1], dtype=torch.float32)
    probs = torch.sigmoid(logits)
    print("LOGITS:: ", logits)
    print("TARGETS:: ", targets)
    print("PROBABILITY:: ", probs)

    # BCE loss and the estimated probability of the correct class
    # (prediction confidence).
    bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    pt = torch.exp(-bce_loss)
    print("BCE_LOSS:: ", bce_loss)
    print("PREDICTION_CONFIDENCE:: ", pt)

    # Track change of gamma (alpha fixed).
    print("Impact of changing gamma")
    for gamma in [0.5, 1, 2, 3, 4, 5]:
        print(f"gamma={gamma}:", focal_loss(logits, targets, alpha=0.25, gamma=gamma))

    # Track change of alpha (gamma fixed). alpha_t is only meaningful for
    # alpha in [0, 1]; larger values would give negatives a negative weight.
    print("Impact of changing alpha")
    for alpha in [0, 0.25, 0.5, 0.75, 1]:
        print(f"alpha={alpha}:", focal_loss(logits, targets, alpha=alpha, gamma=2))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.