Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
""" | |
import torch | |
from .box_ops import box_xyxy_to_cxcywh | |
def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    Args:
        reg_max (int): Max number of the discrete bins.
        up (Tensor): Controls upper bounds of the sequence,
                     where maximum offset is ±up * H / W. Only ``up[0]`` is read.
        reg_scale (float | Tensor): Controls the curvature of the Weighting Function.
                     Larger values result in flatter weights near the central axis W(reg_max/2)=0
                     and steeper weights at both ends. Only its absolute value is used.
        deploy (bool): If True, the sequence is computed with Python floats and
                     returned as a freshly built tensor with the dtype/device of ``up``.

    Returns:
        Tensor: Sequence of Weighting Function as a 1-D tensor
                (``reg_max + 1`` entries when ``reg_max`` is even).
    """
    upper_bound1 = abs(up[0]) * abs(reg_scale)
    upper_bound2 = upper_bound1 * 2
    if deploy:
        # Deployment mode works on plain Python numbers.
        upper_bound1 = upper_bound1.item()
        upper_bound2 = upper_bound2.item()

    half_bins = reg_max // 2
    step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
    left_values = [-(step**i) + 1 for i in range(half_bins - 1, 0, -1)]
    right_values = [step**i - 1 for i in range(1, half_bins)]

    if deploy:
        values = [-upper_bound2, *left_values, 0.0, *right_values, upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)

    values = [
        -upper_bound2,
        *left_values,
        torch.zeros_like(up[0][None]),
        *right_values,
        upper_bound2,
    ]
    # With a Python-float reg_scale every entry above is a 0-dim tensor, which
    # torch.cat rejects; flattening each piece to 1-D handles both the scalar
    # and the shape-(1,) tensor case without changing the latter's result.
    return torch.cat([v.reshape(-1) for v in values], 0)
def translate_gt(gt, reg_max, reg_scale, up):
    """
    Turns continuous ground-truth values into a two-bin distribution target.

    Each value is located between two neighbouring entries of the Weighting
    Function sequence; the index of the left entry is returned together with
    linear interpolation weights for the left and right entries. Values that
    fall outside the sequence get all weight on the nearest end.

    Args:
        gt (Tensor): Ground truth values; flattened to shape (N, ) internally.
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Left-bin index per value as float, shape (N, ).
            - weight_right (Tensor): Weight of the right bin, shape (N, ).
            - weight_left (Tensor): Weight of the left bin, shape (N, ).
    """
    targets = gt.reshape(-1)
    bin_edges = weighting_function(reg_max, up, reg_scale)

    # Number of edges at or below each target, minus one, is the left edge index
    # (-1 when the target lies below the first edge).
    at_or_below = (bin_edges.unsqueeze(0) - targets.unsqueeze(1)) <= 0
    indices = (torch.sum(at_or_below, dim=1) - 1).float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # Targets that sit between two edges: interpolate linearly by distance.
    in_range = (indices >= 0) & (indices < reg_max)
    left_idx = indices[in_range].long()
    inside = targets[in_range]
    dist_to_left = torch.abs(inside - bin_edges[left_idx])
    dist_to_right = torch.abs(bin_edges[left_idx + 1] - inside)
    weight_right[in_range] = dist_to_left / (dist_to_left + dist_to_right)
    weight_left[in_range] = 1.0 - weight_right[in_range]

    # Targets below the first edge: everything on the left bin of index 0.
    below = indices < 0
    weight_right[below] = 0.0
    weight_left[below] = 1.0
    indices[below] = 0.0

    # Targets at or beyond the last edge: everything on the right bin.
    above = indices >= reg_max
    weight_right[above] = 1.0
    weight_left[above] = 0.0
    indices[above] = reg_max - 0.1

    return indices, weight_right, weight_left
def distance2bbox(points, distance, reg_scale):
    """
    Converts per-edge distances around a reference box into box coordinates.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) reference boxes as [x, y, w, h],
                         with (x, y) the center and (w, h) the size.
        distance (Tensor): (B, N, 4) or (N, 4) distances from the point to the
                           left, top, right, and bottom boundaries.
        reg_scale (float): Controls the curvature of the Weighting Function;
                           only its absolute value is used.

    Returns:
        Tensor: Result of ``box_xyxy_to_cxcywh`` applied to the decoded
                [x1, y1, x2, y2] boxes, in (N, 4) or (B, N, 4) layout.
    """
    reg_scale = abs(reg_scale)

    center_x = points[..., 0]
    center_y = points[..., 1]
    # Width/height expressed in units of reg_scale.
    unit_w = points[..., 2] / reg_scale
    unit_h = points[..., 3] / reg_scale

    x1 = center_x - (0.5 * reg_scale + distance[..., 0]) * unit_w
    y1 = center_y - (0.5 * reg_scale + distance[..., 1]) * unit_h
    x2 = center_x + (0.5 * reg_scale + distance[..., 2]) * unit_w
    y2 = center_y + (0.5 * reg_scale + distance[..., 3]) * unit_h

    corners = torch.stack([x1, y1, x2, y2], -1)
    return box_xyxy_to_cxcywh(corners)
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Encodes boxes as distribution targets relative to reference points.

    The normalized distances from each reference center to the four box edges
    are computed first, then passed through ``translate_gt`` to obtain bin
    indices and interpolation weights.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        reg_max (float): Maximum bin value.
        reg_scale (float): Controlling curvature of W(n); absolute value is used.
        up (Tensor): Controlling upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Flattened bin indices, right-bin weights
        and left-bin weights, all detached from the autograd graph.
    """
    reg_scale = abs(reg_scale)

    center_x = points[:, 0]
    center_y = points[:, 1]
    # Size of one reg_scale unit; the tiny constant avoids division by zero.
    unit_w = points[..., 2] / reg_scale + 1e-16
    unit_h = points[..., 3] / reg_scale + 1e-16

    left = (center_x - bbox[:, 0]) / unit_w - 0.5 * reg_scale
    top = (center_y - bbox[:, 1]) / unit_h - 0.5 * reg_scale
    right = (bbox[:, 2] - center_x) / unit_w - 0.5 * reg_scale
    bottom = (bbox[:, 3] - center_y) / unit_h - 0.5 * reg_scale

    edge_distances = torch.stack([left, top, right, bottom], -1)
    four_lens, weight_right, weight_left = translate_gt(edge_distances, reg_max, reg_scale, up)

    if reg_max is not None:
        four_lens = four_lens.clamp(min=0, max=reg_max - eps)

    return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()