|
"""
|
|
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
|
|
"""
|
|
|
|
import torch
|
|
|
|
from .box_ops import box_xyxy_to_cxcywh
|
|
|
|
|
|
def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    The sequence is symmetric about zero and has ``2 * (reg_max // 2) + 1``
    entries (``reg_max + 1`` for even ``reg_max``): the inner entries are
    ``±(step ** i - 1)`` with ``step = (|up[0]| * |reg_scale| + 1) ** (2 / (reg_max - 2))``,
    and the two outermost entries are ``±2 * |up[0]| * |reg_scale|``.

    Args:
        reg_max (int): Max number of the discrete bins. ``reg_max == 2`` raises
                       ``ZeroDivisionError`` because of the ``2 / (reg_max - 2)`` exponent.
        up (Tensor): Controls upper bounds of the sequence,
                     where maximum offset is ±up * H / W. Only ``up[0]`` is read.
        reg_scale (float | Tensor): Controls the curvature of the Weighting Function.
                     Larger values result in flatter weights near the central axis W(reg_max/2)=0
                     and steeper weights at both ends. A Python number or a
                     one-element tensor are both accepted.
        deploy (bool): If True, the values are computed as Python floats and packed
                     into a fresh tensor with ``up``'s dtype and device.

    Returns:
        Tensor: 1-D sequence of Weighting Function values.
    """
    half = reg_max // 2
    exponent = 2 / (reg_max - 2)
    bound = abs(up[0]) * abs(reg_scale)

    if deploy:
        # Work with plain Python floats so the result is a constant tensor.
        upper_bound1 = bound.item()
        upper_bound2 = (bound * 2).item()
        step = (upper_bound1 + 1) ** exponent
        left_values = [-(step**i) + 1 for i in range(half - 1, 0, -1)]
        right_values = [step**i - 1 for i in range(1, half)]
        values = [-upper_bound2] + left_values + [0.0] + right_values + [upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)

    # torch.cat rejects 0-d tensors, which is what ``bound`` is when reg_scale
    # is a plain number. Flattening to 1-D is a no-op for a one-element
    # tensor reg_scale and makes the scalar case work as documented.
    upper_bound1 = bound.reshape(-1)
    upper_bound2 = upper_bound1 * 2
    step = (upper_bound1 + 1) ** exponent
    left_values = [-(step**i) + 1 for i in range(half - 1, 0, -1)]
    right_values = [step**i - 1 for i in range(1, half)]
    values = (
        [-upper_bound2]
        + left_values
        + [torch.zeros_like(up[0][None])]
        + right_values
        + [upper_bound2]
    )
    return torch.cat(values, 0)
|
|
|
|
|
|
def translate_gt(gt, reg_max, reg_scale, up):
    """
    Maps continuous ground-truth (GT) values onto the discrete bins of W(n).

    The bin edges are obtained from ``weighting_function(reg_max, up, reg_scale)``.
    For every GT value the function finds the nearest edge on its left and
    splits a unit weight between that edge and the next one, proportionally
    to how close the value is to each of them.

    Args:
        gt (Tensor): Ground truth values; flattened to shape (N, ) internally.
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Float index of the left bin for each GT value, shape (N, ).
              Values below the first edge get 0; values whose left index would be
              ``>= reg_max`` get ``reg_max - 0.1``.
            - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
            - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
    """
    targets = gt.reshape(-1)
    bin_edges = weighting_function(reg_max, up, reg_scale)

    # For each target, count the edges that are <= the target; one less than
    # that count is the index of the closest edge on the left (-1 if none).
    at_or_below = (bin_edges[None] - targets[:, None]) <= 0
    indices = (at_or_below.sum(dim=1) - 1).float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # Targets that fall between two edges: linear interpolation weights.
    in_range = (indices >= 0) & (indices < reg_max)
    left_idx = indices[in_range].long()
    targets_in_range = targets[in_range]

    dist_to_left = (targets_in_range - bin_edges[left_idx]).abs()
    dist_to_right = (bin_edges[left_idx + 1] - targets_in_range).abs()

    weight_right[in_range] = dist_to_left / (dist_to_left + dist_to_right)
    weight_left[in_range] = 1.0 - weight_right[in_range]

    # Targets below the first edge: all weight on the first (left) bin.
    below = indices < 0
    weight_right[below] = 0.0
    weight_left[below] = 1.0
    indices[below] = 0.0

    # Targets at or beyond the last edge: all weight on the right bin.
    above = indices >= reg_max
    weight_right[above] = 1.0
    weight_left[above] = 0.0
    indices[above] = reg_max - 0.1

    return indices, weight_right, weight_left
|
|
|
|
|
|
def distance2bbox(points, distance, reg_scale):
    """
    Decodes edge-distances into bounding box coordinates.

    Each distance is expressed in units of ``size / |reg_scale|`` and is offset
    by ``0.5 * |reg_scale|``, so a distance of zero places the edge exactly
    half a width/height away from the centre.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
                         where (x, y) is the center and (w, h) are width and height.
        distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
                           point to the left, top, right, and bottom boundaries.
        reg_scale (float): Controls the curvature of the Weighting Function;
                           only its absolute value is used.

    Returns:
        Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
    """
    scale = abs(reg_scale)
    half_span = 0.5 * scale

    cx, cy = points[..., 0], points[..., 1]
    # Size of one distance unit along each axis.
    unit_w = points[..., 2] / scale
    unit_h = points[..., 3] / scale

    corners = [
        cx - (half_span + distance[..., 0]) * unit_w,  # x1
        cy - (half_span + distance[..., 1]) * unit_h,  # y1
        cx + (half_span + distance[..., 2]) * unit_w,  # x2
        cy + (half_span + distance[..., 3]) * unit_h,  # y2
    ]

    return box_xyxy_to_cxcywh(torch.stack(corners, -1))
|
|
|
|
|
|
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Converts bounding box coordinates to distances from a reference point.

    The four edge distances are normalised by ``size / |reg_scale|`` and shifted
    by ``-0.5 * |reg_scale|``, then mapped to bin indices and interpolation
    weights through ``translate_gt``.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        reg_max (float): Maximum bin value.
        reg_scale (float): Controlling curvature of W(n); only its absolute value is used.
        up (Tensor): Controlling upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: all detached from the autograd graph.
            - Flattened bin indices, clamped to ``[0, reg_max - eps]``.
            - Weight of the right bin for each index.
            - Weight of the left bin for each index.
    """
    reg_scale = abs(reg_scale)
    half_span = 0.5 * reg_scale

    cx, cy = points[:, 0], points[:, 1]
    # Size of one distance unit per axis; the tiny constant avoids division by zero.
    unit_w = points[..., 2] / reg_scale + 1e-16
    unit_h = points[..., 3] / reg_scale + 1e-16

    edge_distances = torch.stack(
        [
            (cx - bbox[:, 0]) / unit_w - half_span,  # left
            (cy - bbox[:, 1]) / unit_h - half_span,  # top
            (bbox[:, 2] - cx) / unit_w - half_span,  # right
            (bbox[:, 3] - cy) / unit_h - half_span,  # bottom
        ],
        -1,
    )

    bin_indices, weight_right, weight_left = translate_gt(edge_distances, reg_max, reg_scale, up)
    if reg_max is not None:
        bin_indices = bin_indices.clamp(min=0, max=reg_max - eps)

    return bin_indices.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
|
|
|