joonhyun23452's picture
open proxydet demo
8075387
import torch
from torch.nn import functional as F
# TODO: merge these two function
def heatmap_focal_loss(
inputs,
targets,
pos_inds,
labels,
alpha: float = -1,
beta: float = 4,
gamma: float = 2,
reduction: str = 'sum',
sigmoid_clamp: float = 1e-4,
ignore_high_fp: float = -1.,
):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: (sum_l N*Hl*Wl, C)
targets: (sum_l N*Hl*Wl, C)
pos_inds: N
labels: N
Returns:
Loss tensor with the reduction option applied.
"""
pred = torch.clamp(inputs.sigmoid_(), min=sigmoid_clamp, max=1-sigmoid_clamp)
neg_weights = torch.pow(1 - targets, beta)
pos_pred_pix = pred[pos_inds] # N x C
pos_pred = pos_pred_pix.gather(1, labels.unsqueeze(1))
pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma)
neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights
if ignore_high_fp > 0:
not_high_fp = (pred < ignore_high_fp).float()
neg_loss = not_high_fp * neg_loss
if reduction == "sum":
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if alpha >= 0:
pos_loss = alpha * pos_loss
neg_loss = (1 - alpha) * neg_loss
return - pos_loss, - neg_loss
heatmap_focal_loss_jit = torch.jit.script(heatmap_focal_loss)
# heatmap_focal_loss_jit = heatmap_focal_loss
def binary_heatmap_focal_loss(
inputs,
targets,
pos_inds,
alpha: float = -1,
beta: float = 4,
gamma: float = 2,
sigmoid_clamp: float = 1e-4,
ignore_high_fp: float = -1.,
):
"""
Args:
inputs: (sum_l N*Hl*Wl,)
targets: (sum_l N*Hl*Wl,)
pos_inds: N
Returns:
Loss tensor with the reduction option applied.
"""
pred = torch.clamp(inputs.sigmoid_(), min=sigmoid_clamp, max=1-sigmoid_clamp)
neg_weights = torch.pow(1 - targets, beta)
pos_pred = pred[pos_inds] # N
pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma)
neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights
if ignore_high_fp > 0:
not_high_fp = (pred < ignore_high_fp).float()
neg_loss = not_high_fp * neg_loss
pos_loss = - pos_loss.sum()
neg_loss = - neg_loss.sum()
if alpha >= 0:
pos_loss = alpha * pos_loss
neg_loss = (1 - alpha) * neg_loss
return pos_loss, neg_loss
binary_heatmap_focal_loss_jit = torch.jit.script(binary_heatmap_focal_loss)