D-FINE / src/zoo/dfine/dfine_utils.py
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""

import torch

from .box_ops import box_xyxy_to_cxcywh


def weighting_function(reg_max, up, reg_scale, deploy=False):
"""
Generates the non-uniform Weighting Function W(n) for bounding box regression.
Args:
reg_max (int): Max number of the discrete bins.
up (Tensor): Controls upper bounds of the sequence,
where maximum offset is ±up * H / W.
reg_scale (float): Controls the curvature of the Weighting Function.
Larger values result in flatter weights near the central axis W(reg_max/2)=0
and steeper weights at both ends.
deploy (bool): If True, uses deployment mode settings.
Returns:
Tensor: Sequence of Weighting Function.
"""
if deploy:
upper_bound1 = (abs(up[0]) * abs(reg_scale)).item()
upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item()
step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
values = (
[-upper_bound2]
+ left_values
+ [torch.zeros_like(up[0][None])]
+ right_values
+ [upper_bound2]
)
return torch.tensor(values, dtype=up.dtype, device=up.device)
else:
upper_bound1 = abs(up[0]) * abs(reg_scale)
upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
values = (
[-upper_bound2]
+ left_values
+ [torch.zeros_like(up[0][None])]
+ right_values
+ [upper_bound2]
)
return torch.cat(values, 0)
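

# A minimal usage sketch (assumed illustrative values: reg_max=32, up=0.5, reg_scale=4;
# `up` and `reg_scale` are given as 1-element tensors so that torch.cat above receives
# 1-D pieces). W(n) has reg_max + 1 entries, is zero at the central bin, and reaches
# +/- 2 * up * reg_scale at the two ends.
def _example_weighting_function():
    reg_max = 32
    w = weighting_function(reg_max, torch.tensor([0.5]), torch.tensor([4.0]))
    assert w.shape == (reg_max + 1,)
    assert torch.isclose(w[reg_max // 2], torch.tensor(0.0))
    assert torch.isclose(w[0], torch.tensor(-4.0))
    assert torch.isclose(w[-1], torch.tensor(4.0))
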
def translate_gt(gt, reg_max, reg_scale, up):
"""
Decodes bounding box ground truth (GT) values into distribution-based GT representations.
This function maps continuous GT values into discrete distribution bins, which can be used
for regression tasks in object detection models. It calculates the indices of the closest
bins to each GT value and assigns interpolation weights to these bins based on their proximity
to the GT value.
Args:
gt (Tensor): Ground truth bounding box values, shape (N, ).
reg_max (int): Maximum number of discrete bins for the distribution.
reg_scale (float): Controls the curvature of the Weighting Function.
up (Tensor): Controls the upper bounds of the Weighting Function.
Returns:
Tuple[Tensor, Tensor, Tensor]:
- indices (Tensor): Index of the left bin closest to each GT value, shape (N, ).
- weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
- weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
"""
gt = gt.reshape(-1)
function_values = weighting_function(reg_max, up, reg_scale)
# Find the closest left-side indices for each value
diffs = function_values.unsqueeze(0) - gt.unsqueeze(1)
mask = diffs <= 0
closest_left_indices = torch.sum(mask, dim=1) - 1
# Calculate the weights for the interpolation
indices = closest_left_indices.float()
weight_right = torch.zeros_like(indices)
weight_left = torch.zeros_like(indices)
valid_idx_mask = (indices >= 0) & (indices < reg_max)
valid_indices = indices[valid_idx_mask].long()
# Obtain distances
left_values = function_values[valid_indices]
right_values = function_values[valid_indices + 1]
left_diffs = torch.abs(gt[valid_idx_mask] - left_values)
right_diffs = torch.abs(right_values - gt[valid_idx_mask])
# Valid weights
weight_right[valid_idx_mask] = left_diffs / (left_diffs + right_diffs)
weight_left[valid_idx_mask] = 1.0 - weight_right[valid_idx_mask]
# Invalid weights (out of range)
invalid_idx_mask_neg = indices < 0
weight_right[invalid_idx_mask_neg] = 0.0
weight_left[invalid_idx_mask_neg] = 1.0
indices[invalid_idx_mask_neg] = 0.0
invalid_idx_mask_pos = indices >= reg_max
weight_right[invalid_idx_mask_pos] = 1.0
weight_left[invalid_idx_mask_pos] = 0.0
indices[invalid_idx_mask_pos] = reg_max - 0.1
return indices, weight_right, weight_left
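

# A minimal encoding sketch (assumed illustrative values: reg_max=32, up=0.5,
# reg_scale=4, passed as 1-element tensors). For in-range targets, the returned
# left/right bin weights interpolate the continuous value back exactly from W(n).
def _example_translate_gt():
    reg_max = 32
    up, reg_scale = torch.tensor([0.5]), torch.tensor([4.0])
    w = weighting_function(reg_max, up, reg_scale)
    gt = torch.tensor([-0.7, 0.0, 1.3])  # continuous targets inside [W(0), W(reg_max))
    idx, w_r, w_l = translate_gt(gt, reg_max, reg_scale, up)
    rec = w_l * w[idx.long()] + w_r * w[idx.long() + 1]
    assert torch.allclose(rec, gt, atol=1e-5)
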
def distance2bbox(points, distance, reg_scale):
"""
Decodes edge-distances into bounding box coordinates.
Args:
points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
where (x, y) is the center and (w, h) are width and height.
distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
point to the left, top, right, and bottom boundaries.
reg_scale (float): Controls the curvature of the Weighting Function.
Returns:
Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
"""
reg_scale = abs(reg_scale)
x1 = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale)
y1 = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale)
x2 = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale)
y2 = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale)
bboxes = torch.stack([x1, y1, x2, y2], -1)
return box_xyxy_to_cxcywh(bboxes)
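

# A minimal decoding sketch with a hypothetical anchor and reg_scale=4 (illustrative
# value): all-zero distances decode to the anchor box itself, already in cxcywh form.
def _example_distance2bbox():
    anchors = torch.tensor([[0.5, 0.5, 0.2, 0.1]])  # [cx, cy, w, h]
    decoded = distance2bbox(anchors, torch.zeros(1, 4), reg_scale=4.0)
    assert torch.allclose(decoded, anchors, atol=1e-6)
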
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
"""
Converts bounding box coordinates to distances from a reference point.
Args:
points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
reg_max (float): Maximum bin value.
reg_scale (float): Controling curvarture of W(n).
up (Tensor): Controling upper bounds of W(n).
eps (float): Small value to ensure target < reg_max.
Returns:
Tensor: Decoded distances.
"""
reg_scale = abs(reg_scale)
left = (points[:, 0] - bbox[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale
top = (points[:, 1] - bbox[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale
right = (bbox[:, 2] - points[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale
bottom = (bbox[:, 3] - points[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale
four_lens = torch.stack([left, top, right, bottom], -1)
four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)
if reg_max is not None:
four_lens = four_lens.clamp(min=0, max=reg_max - eps)
return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
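

# A minimal round-trip sketch tying the helpers together (assumed illustrative values:
# reg_max=32, up=0.5, reg_scale=4, passed as 1-element tensors): the binned distances
# from bbox2distance, re-expanded through W(n) with their left/right weights, decode
# back to the ground-truth box via distance2bbox.
def _example_bbox2distance_roundtrip():
    reg_max = 32
    up, reg_scale = torch.tensor([0.5]), torch.tensor([4.0])
    anchors = torch.tensor([[0.5, 0.5, 0.2, 0.1]])     # reference points [cx, cy, w, h]
    target = torch.tensor([[0.42, 0.46, 0.61, 0.56]])  # ground-truth box, xyxy format
    idx, w_r, w_l = bbox2distance(anchors, target, reg_max, reg_scale, up)
    w = weighting_function(reg_max, up, reg_scale)
    dist = (w_l * w[idx.long()] + w_r * w[idx.long() + 1]).reshape(1, 4)
    decoded = distance2bbox(anchors, dist, reg_scale)
    assert torch.allclose(decoded, box_xyxy_to_cxcywh(target), atol=1e-4)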