# Spaces: Running on Zero
import torch | |
import torch.nn.functional as F | |
def cross_entropy_loss_for_junction(logits, positive):
    """Two-class softmax cross entropy, averaged over every element.

    Log-softmax is taken over all of dim 1 of ``logits``. The loss then scores
    channel 1 as the positive class and channel 0 as the negative class.

    Args:
        logits: class scores along dim 1 (at least two channels), presumably
            (N, 2, H, W) -- TODO confirm layout against callers.
        positive: label weights broadcastable against (N, 1, ...) slices of the
            log-probabilities; presumably a 0/1 mask of shape (N, 1, H, W).
            Fractional values blend the two terms linearly.

    Returns:
        Scalar tensor: mean of
        ``positive * -log p1 + (1 - positive) * -log p0``.
    """
    neg_log_prob = F.log_softmax(logits, dim=1).neg()
    # Indexing with [:, None, k] keeps a singleton class axis, so each term is
    # shaped (N, 1, ...) and lines up with ``positive``.
    pos_term = positive * neg_log_prob[:, None, 1]
    neg_term = (1 - positive) * neg_log_prob[:, None, 0]
    return (pos_term + neg_term).mean()
def focal_loss_for_junction(logits, positive, gamma=2.0):
    """Softmax focal loss (no class weighting), averaged over every element.

    Args:
        logits: class scores with classes along dim 1, e.g. (N, C, H, W).
        positive: target class per element, in either layout:
            * integer class indices shaped like ``logits`` without dim 1,
              e.g. (N, H, W); passed straight to ``F.cross_entropy``;
            * a tensor with the same rank as ``logits`` and size 1 on dim 1,
              e.g. (N, 1, H, W), holding 0/1 labels (the layout that
              ``cross_entropy_loss_for_junction`` uses). It is squeezed along
              dim 1 and cast to ``long``.
        gamma: exponent of the modulating factor (1 - p_t); 0 gives plain
            cross entropy.

    Returns:
        Scalar tensor: mean of ``-log(p_t) * (1 - p_t) ** gamma``.
    """
    if positive.dim() == logits.dim() and positive.size(1) == 1:
        positive = positive[:, 0].long()
    # Per-element -log p_t, shaped like the target (no class axis).
    ce_loss = F.cross_entropy(logits, positive, reduction='none')
    # p_t is the predicted probability of the target class. Deriving it from
    # ce_loss keeps it element-aligned with ce_loss. Slicing the softmax as
    # (N, 1, ...) would broadcast against the (N, ...) target and mix samples
    # across the batch when N > 1.
    p_t = torch.exp(-ce_loss)
    loss = ce_loss * (1 - p_t) ** gamma
    return loss.mean()
def sigmoid_l1_loss(logits, targets, offset=0.0, mask=None):
    """Mean absolute error between ``sigmoid(logits) + offset`` and ``targets``.

    Args:
        logits: raw predictions; squashed with a sigmoid, then shifted by ``offset``.
        targets: values to regress, broadcastable against ``logits``.
        offset: constant added to the sigmoid output before the comparison.
        mask: optional per-element weights. They are averaged over dims 3 and 2
            (keepdim), so the mask needs at least 4 dims; presumably
            (N, C, H, W) -- TODO confirm. Each error is scaled by ``mask``
            divided by that mean. Slices whose mask mean is exactly 0 use a
            divisor of 1 to avoid dividing by zero.

    Returns:
        Scalar tensor: mean of the (optionally weighted) absolute errors.
    """
    pred = torch.sigmoid(logits) + offset
    abs_err = (pred - targets).abs()
    if mask is None:
        return abs_err.mean()
    weight = mask.mean(3, True).mean(2, True)
    weight[weight == 0] = 1
    return (abs_err * (mask / weight)).mean()
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting);
                any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    # Validate up front. Without this check, a typo such as "avg" would
    # silently return the unreduced per-element tensor.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid value for 'reduction': {reduction!r}; "
            "supported modes are 'none', 'mean', 'sum'"
        )
    # Compute in float32 regardless of the incoming dtype.
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: predicted probability of the labelled class.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss