Spaces:
Runtime error
Runtime error
File size: 4,043 Bytes
8075387 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
from torch import nn
class IOULoss(nn.Module):
    """IoU-family regression loss for boxes encoded as side distances.

    Column ``k`` of ``pred``/``target`` (``x[:, k]``) is read as the
    left (0), top (1), right (2) and bottom (3) distance from a shared
    reference point to the box edges. A box's area is therefore
    ``(left + right) * (top + bottom)``.

    Args:
        loc_loss_type: ``'iou'`` (``-log(IoU)``), ``'linear_iou'``
            (``1 - IoU``) or ``'giou'`` (``1 - GIoU``). Other values raise
            ``NotImplementedError`` when ``forward`` is called.
        eps: small constant added to the enclosing-box area before dividing
            by it. This keeps the GIoU penalty finite for zero-size boxes.
    """

    def __init__(self, loc_loss_type='iou', eps=1e-7):
        super().__init__()
        self.loc_loss_type = loc_loss_type
        self.eps = eps

    def forward(self, pred, target, weight=None, reduction='sum'):
        """Compute the configured IoU loss between ``pred`` and ``target``.

        Args:
            pred, target: tensors indexed as ``x[:, k]`` for ``k`` in 0..3,
                holding (left, top, right, bottom) distances.
            weight: optional tensor multiplied element-wise into the
                per-element losses before reduction.
            reduction: ``'sum'`` (sum of all losses), ``'batch'`` (sum over
                dim 1 of the per-element losses) or ``'none'`` (no reduction).

        Returns:
            The (optionally weighted) loss, reduced as requested.

        Raises:
            NotImplementedError: for an unknown ``loc_loss_type`` or
                ``reduction``.
        """
        pred_left, pred_top = pred[:, 0], pred[:, 1]
        pred_right, pred_bottom = pred[:, 2], pred[:, 3]
        target_left, target_top = target[:, 0], target[:, 1]
        target_right, target_bottom = target[:, 2], target[:, 3]

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        # Intersection: the nearer edge on each side bounds the overlap.
        w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right)
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top)
        # Smallest enclosing box: the farther edge on each side.
        g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right)
        g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top)
        ac_union = g_w_intersect * g_h_intersect

        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect
        # +1.0 smoothing: the ratio stays defined when both areas are zero.
        # It also stays positive (so -log is finite) when the boxes do not
        # overlap, for non-negative distances.
        ious = (area_intersect + 1.0) / (area_union + 1.0)
        # eps prevents 0/0 -> NaN when the enclosing box is degenerate.
        gious = ious - (ac_union - area_union) / (ac_union + self.eps)

        if self.loc_loss_type == 'iou':
            losses = -torch.log(ious)
        elif self.loc_loss_type == 'linear_iou':
            losses = 1 - ious
        elif self.loc_loss_type == 'giou':
            losses = 1 - gious
        else:
            raise NotImplementedError(
                f"unsupported loc_loss_type: {self.loc_loss_type!r}")

        if weight is not None:
            losses = losses * weight

        if reduction == 'sum':
            return losses.sum()
        if reduction == 'batch':
            return losses.sum(dim=[1])
        if reduction == 'none':
            return losses
        raise NotImplementedError(f"unsupported reduction: {reduction!r}")
def giou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Generalized Intersection over Union Loss (Hamid Rezatofighi et al.)
    https://arxiv.org/abs/1902.09630

    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    boxes do not overlap and scales with the size of their smallest enclosing box.
    The loss value is symmetric, so the boxes1 and boxes2 arguments are
    interchangeable (note that only boxes1 is validated, see Raises).

    Args:
        boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged. An empty input yields 0
                         (still attached to the autograd graph) instead of NaN.
                 'sum': The output will be summed.
                 Any other value is treated like 'none'.
        eps (float): small number to prevent division by zero

    Returns:
        Per-box losses (the leading dims of the inputs) for 'none', otherwise
        a scalar tensor.

    Raises:
        AssertionError: if any box in boxes1 has x1 > x2 or y1 > y2.
    """
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Only boxes1 is checked here; boxes2 is used as given.
    assert (x2 >= x1).all(), "bad box: x1 larger than x2"
    assert (y2 >= y1).all(), "bad box: y1 larger than y2"

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    # Only overlapping pairs get a non-zero intersection; the rest stay 0.
    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
    iouk = intsctk / (unionk + eps)

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)

    area_c = (xc2 - xc1) * (yc2 - yc1)
    miouk = iouk - ((area_c - unionk) / (area_c + eps))

    loss = 1 - miouk

    if reduction == "mean":
        # mean() of an empty tensor is NaN; 0 * sum() is 0 and keeps the graph.
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()

    return loss