# D-FINE bounding-box distribution utilities.
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import torch
from .box_ops import box_xyxy_to_cxcywh
def weighting_function(reg_max, up, reg_scale, deploy=False):
    """
    Generates the non-uniform Weighting Function W(n) for bounding box regression.

    The sequence is symmetric around zero: exponentially spaced negative
    offsets, a central zero, exponentially spaced positive offsets, and the
    two ends capped at -/+ upper_bound2. Length is reg_max + 1.

    Args:
        reg_max (int): Max number of the discrete bins.
        up (Tensor): Controls upper bounds of the sequence,
            where maximum offset is ±up * H / W.
        reg_scale (float): Controls the curvature of the Weighting Function.
            Larger values result in flatter weights near the central axis W(reg_max/2)=0
            and steeper weights at both ends.
            NOTE(review): the non-deploy branch uses torch.cat, which requires
            1-D elements — this works when reg_scale (and hence the bounds) is a
            1-element tensor, as upstream callers pass; confirm for plain floats.
        deploy (bool): If True, builds the sequence from detached Python floats
            (no autograd graph) for export/inference.

    Returns:
        Tensor: Sequence of Weighting Function values, shape (reg_max + 1,).
    """
    if deploy:
        upper_bound1 = (abs(up[0]) * abs(reg_scale)).item()
        upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item()
        step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
        left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
        right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
        # Use a plain float for the central zero bin: every other element here
        # is a Python float, and passing a list that mixes floats with a 1-D
        # Tensor into torch.tensor() is rejected/deprecated in recent PyTorch.
        values = [-upper_bound2] + left_values + [0.0] + right_values + [upper_bound2]
        return torch.tensor(values, dtype=up.dtype, device=up.device)
    else:
        # Training path: keep everything as tensors so gradients (and device
        # placement) follow `up` / `reg_scale`.
        upper_bound1 = abs(up[0]) * abs(reg_scale)
        upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
        step = (upper_bound1 + 1) ** (2 / (reg_max - 2))
        left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
        right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)]
        values = (
            [-upper_bound2]
            + left_values
            + [torch.zeros_like(up[0][None])]
            + right_values
            + [upper_bound2]
        )
        return torch.cat(values, 0)
def translate_gt(gt, reg_max, reg_scale, up):
    """
    Decodes bounding box ground truth (GT) values into distribution-based GT representations.

    Each continuous GT value lies between two adjacent bins of the Weighting
    Function W(n); the result is the left bin index plus linear interpolation
    weights for the left/right bins. Values falling outside the bin range are
    clamped onto the first or last bin.

    Args:
        gt (Tensor): Ground truth bounding box values, shape (N, ).
        reg_max (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Index of the left bin closest to each GT value, shape (N, ).
            - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
            - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
    """
    gt = gt.reshape(-1)
    anchors = weighting_function(reg_max, up, reg_scale)

    # For every GT value, count how many anchor values do not exceed it;
    # one less than that count is the index of the nearest left-hand bin.
    not_greater = anchors.unsqueeze(0) <= gt.unsqueeze(1)
    indices = (not_greater.sum(dim=1) - 1).float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # In-range values: interpolate linearly between the two surrounding bins.
    in_range = (indices >= 0) & (indices < reg_max)
    bin_idx = indices[in_range].long()
    dist_left = torch.abs(gt[in_range] - anchors[bin_idx])
    dist_right = torch.abs(anchors[bin_idx + 1] - gt[in_range])
    weight_right[in_range] = dist_left / (dist_left + dist_right)
    weight_left[in_range] = 1.0 - weight_right[in_range]

    # Below the first bin: all mass on the left bin at index 0.
    below = indices < 0
    weight_right[below] = 0.0
    weight_left[below] = 1.0
    indices[below] = 0.0

    # At or above the last bin: all mass on the right bin; keep the index just
    # under reg_max so downstream floor/long conversion stays in range.
    above = indices >= reg_max
    weight_right[above] = 1.0
    weight_left[above] = 0.0
    indices[above] = reg_max - 0.1

    return indices, weight_right, weight_left
def distance2bbox(points, distance, reg_scale):
    """
    Decodes edge-distances into bounding box coordinates.

    Args:
        points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h],
                         where (x, y) is the center and (w, h) are width and height.
        distance (Tensor): (B, N, 4) or (N, 4), representing distances from the
                           point to the left, top, right, and bottom boundaries.
        reg_scale (float): Controls the curvature of the Weighting Function.

    Returns:
        Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h].
    """
    reg_scale = abs(reg_scale)
    # Distances are measured in units of (w / reg_scale) horizontally and
    # (h / reg_scale) vertically, each with a constant 0.5 * reg_scale offset.
    unit_x = points[..., 2] / reg_scale
    unit_y = points[..., 3] / reg_scale
    half = 0.5 * reg_scale
    x1 = points[..., 0] - (half + distance[..., 0]) * unit_x
    y1 = points[..., 1] - (half + distance[..., 1]) * unit_y
    x2 = points[..., 0] + (half + distance[..., 2]) * unit_x
    y2 = points[..., 1] + (half + distance[..., 3]) * unit_y
    corners = torch.stack([x1, y1, x2, y2], -1)
    return box_xyxy_to_cxcywh(corners)
def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1):
    """
    Converts bounding box coordinates to distances from a reference point.

    Inverse of `distance2bbox`: each edge distance is normalised by the anchor
    size divided by reg_scale, then translated into distribution bin targets
    via `translate_gt`.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        reg_max (float): Maximum bin value.
        reg_scale (float): Controls the curvature of W(n).
        up (Tensor): Controls the upper bounds of W(n).
        eps (float): Small value to ensure target < reg_max.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: Flattened bin targets plus the
        right/left interpolation weights, all detached from autograd.
    """
    reg_scale = abs(reg_scale)
    # 1e-16 guards against division by zero for degenerate (zero-size) anchors.
    norm_w = points[..., 2] / reg_scale + 1e-16
    norm_h = points[..., 3] / reg_scale + 1e-16
    half = 0.5 * reg_scale
    left = (points[:, 0] - bbox[:, 0]) / norm_w - half
    top = (points[:, 1] - bbox[:, 1]) / norm_h - half
    right = (bbox[:, 2] - points[:, 0]) / norm_w - half
    bottom = (bbox[:, 3] - points[:, 1]) / norm_h - half
    four_lens = torch.stack([left, top, right, bottom], -1)
    four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up)
    if reg_max is not None:
        four_lens = four_lens.clamp(min=0, max=reg_max - eps)
    return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()