Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,445 Bytes
4c954ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
import torch.nn.functional as F
def cross_entropy_loss_for_junction(logits, positive):
    """Binary cross-entropy computed from a two-channel softmax over dim 1.

    Channel 1 of ``logits`` scores the positive class and channel 0 the
    negative class. ``positive`` weights the positive-class term and
    ``1 - positive`` weights the negative-class term, so soft labels are
    accepted as well as hard 0/1 labels.

    Args:
        logits: Unnormalized scores with the class axis at dim 1
            (at least two channels).
        positive: Target multiplied elementwise against a tensor shaped like
            ``logits`` with dim 1 reduced to size 1, so it must broadcast
            against that shape.

    Returns:
        Scalar tensor: the mean per-element loss.
    """
    neg_log_prob = -F.log_softmax(logits, dim=1)
    # select() drops the class axis and unsqueeze(1) restores it as size 1,
    # so each term lines up with ``positive``.
    pos_term = neg_log_prob.select(1, 1).unsqueeze(1)
    neg_term = neg_log_prob.select(1, 0).unsqueeze(1)
    per_element = positive * pos_term + (1 - positive) * neg_term
    return per_element.mean()
def focal_loss_for_junction(logits, positive, gamma=2.0):
    """Softmax focal loss (Lin et al., 2017), averaged over all elements.

    Args:
        logits: Unnormalized class scores with the class axis at dim 1,
            e.g. ``(N, C)`` or ``(N, C, H, W)``.
        positive: Target passed straight to ``F.cross_entropy``. It is
            presumably a tensor of integer class indices shaped like
            ``logits`` without dim 1; with that target, ``exp(-ce)`` below
            is exactly the target-class probability.
        gamma: Focusing exponent applied to the modulating factor
            ``(1 - p_t)``. ``gamma=0`` reduces to plain cross-entropy.

    Returns:
        Scalar tensor: mean of ``-(1 - p_t) ** gamma * log(p_t)``.
    """
    ce_loss = F.cross_entropy(logits, positive, reduction='none')
    # ce_loss == -log(p_t), where p_t is the softmax probability of the
    # target class, so exp(-ce_loss) recovers p_t with exactly ce_loss's
    # shape. Slicing the softmax (prob[:, 1:]) instead keeps a singleton
    # class axis, which broadcasts against the index target into an
    # (N, N, ...) tensor and corrupts the mean.
    p_t = torch.exp(-ce_loss)
    loss = ce_loss * ((1 - p_t) ** gamma)
    return loss.mean()
def sigmoid_l1_loss(logits, targets, offset=0.0, mask=None):
    """Mean absolute error between ``sigmoid(logits) + offset`` and ``targets``.

    When ``mask`` is given, each element's error is scaled by ``mask / w``.
    ``w`` is ``mask`` averaged over dims 3 and 2 with those dims kept, so
    ``mask`` needs at least 4 dims (presumably ``(N, C, H, W)``; confirm
    against callers). This normalizes each slice by the fraction of active
    mask entries. Slices whose mask averages to 0 use ``w = 1`` to avoid
    dividing by zero.

    Args:
        logits: Raw predictions, squashed with a sigmoid.
        targets: Values compared against the shifted sigmoid output.
        offset: Constant added to the sigmoid output before comparison.
        mask: Optional elementwise weights.

    Returns:
        Scalar tensor: the mean (optionally weighted) absolute error.
    """
    shifted = torch.sigmoid(logits) + offset
    abs_err = (shifted - targets).abs()
    if mask is None:
        return abs_err.mean()
    mask_avg = mask.mean(3, True).mean(2, True)
    # Out-of-place replacement of zero averages; the masked entries there are
    # all zero anyway, so their contribution stays 0.
    mask_avg = mask_avg.masked_fill(mask_avg == 0, 1)
    return (abs_err * (mask / mask_avg)).mean()
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """Sigmoid focal loss from RetinaNet (https://arxiv.org/abs/1708.02002).

    Each element is treated as an independent binary classification. The
    per-element binary cross-entropy is down-weighted by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the element's true label.

    Args:
        inputs: Float logits of arbitrary shape (cast to float32 first).
        targets: Binary labels (0 = negative, 1 = positive), same shape as
            ``inputs`` (cast to float32 first).
        alpha: If ``>= 0``, positives are weighted by ``alpha`` and negatives
            by ``1 - alpha``. Negative values (the default ``-1``) disable
            this weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t)``.
        reduction: ``"mean"`` averages, ``"sum"`` sums. Any other value,
            including the default ``"none"``, returns the unreduced tensor.

    Returns:
        The loss tensor after the requested reduction.
    """
    logits = inputs.float()
    labels = targets.float()
    prob = torch.sigmoid(logits)
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    # Probability the model assigns to each element's true label.
    prob_true = prob * labels + (1 - prob) * (1 - labels)
    focal = bce * ((1 - prob_true) ** gamma)
    if alpha >= 0:
        class_weight = alpha * labels + (1 - alpha) * (1 - labels)
        focal = class_weight * focal
    if reduction == "mean":
        return focal.mean()
    if reduction == "sum":
        return focal.sum()
    return focal