import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    """Binary cross-entropy with logits, extended with optional label smoothing."""

    def __init__(self, label_smoothing=0.0, **kwargs):
        super(BCEWithLogitsLoss, self).__init__(**kwargs)
        self.label_smoothing = label_smoothing

    def forward(self, input, target):
        if self.label_smoothing:
            # Smooth hard 0/1 targets towards 0.5 by the smoothing factor.
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        return super(BCEWithLogitsLoss, self).forward(input, target)
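

# Minimal usage sketch, not part of the original module: shows BCEWithLogitsLoss
# with label smoothing on random logits. The helper name, batch size, and number
# of classes below are illustrative assumptions.
def _demo_bce_with_label_smoothing():
    criterion = BCEWithLogitsLoss(label_smoothing=0.1)
    logits = torch.randn(4, 3)                      # raw model outputs
    targets = torch.randint(0, 2, (4, 3)).float()   # hard binary ground truth
    return criterion(logits, targets)               # scalar loss (default 'mean' reduction)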


class SigmoidFocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, label_smoothing=0.0, reduction="mean"):
        """
        Args:
            alpha (float): Weighting factor in range (0, 1) to balance positive vs. negative examples.
                A negative value disables the alpha weighting.
            gamma (float): Focusing parameter that reduces the relative loss for well-classified examples.
            label_smoothing (float): Label smoothing factor that reduces the confidence of the true label.
            reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
                'none': no reduction will be applied,
                'mean': the sum of the output will be divided by the number of elements in the output,
                'sum': the output will be summed.
        """
        super(SigmoidFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        """
        Args:
            input (Tensor): Predicted logits for each example.
            target (Tensor): Ground-truth binary labels (0 or 1) for each example.
        """
        if self.label_smoothing:
            # Smooth hard 0/1 targets towards 0.5 by the smoothing factor.
            target = target * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing

        p = torch.sigmoid(input)
        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        # p_t is the model's probability assigned to the ground-truth class.
        p_t = p * target + (1 - p) * (1 - target)
        # Down-weight well-classified examples with the focusing term (1 - p_t) ** gamma.
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            # Balance positive and negative examples with alpha / (1 - alpha).
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        # Apply the requested reduction.
        if self.reduction == "none":
            pass
        elif self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        else:
            raise ValueError(
                f"Invalid value for arg 'reduction': '{self.reduction}'.\n"
                "Supported reduction modes: 'none', 'mean', 'sum'"
            )
        return loss
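

# Minimal usage sketch, not part of the original module: shows SigmoidFocalLoss on
# per-example logits with 'sum' reduction. The helper name and tensor shapes are
# illustrative; alpha=0.25 / gamma=2.0 are common focal-loss settings, not the
# defaults of this class.
def _demo_sigmoid_focal_loss():
    criterion = SigmoidFocalLoss(alpha=0.25, gamma=2.0, label_smoothing=0.1, reduction="sum")
    logits = torch.randn(8)                        # raw scores for 8 examples
    targets = torch.randint(0, 2, (8,)).float()    # binary labels
    return criterion(logits, targets)              # focal loss summed over all examples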