|
"""
|
|
D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement
|
|
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
|
|
---------------------------------------------------------------------------------
|
|
Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import math
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Compute the elementwise logit (inverse of the sigmoid) of ``x``.

    The input is first clipped into ``[0, 1]``. Both ``x`` and ``1 - x`` are
    then floored at ``eps`` so that the ratio fed to the logarithm is never
    zero or infinite.

    Args:
        x: Tensor of values to invert.
        eps: Lower bound applied to the numerator and the denominator.

    Returns:
        Tensor with the same shape as ``x`` holding ``log(x / (1 - x))``.
    """
    prob = x.clip(min=0.0, max=1.0)
    numerator = prob.clip(min=eps)
    denominator = (1 - prob).clip(min=eps)
    return torch.log(numerator / denominator)
|
|
|
|
|
|
def bias_init_with_prob(prior_prob=0.01):
    """Return the conv/fc bias initial value matching a prior probability.

    The result is ``-log((1 - prior_prob) / prior_prob)``, i.e. the value
    whose sigmoid equals ``prior_prob``.

    Args:
        prior_prob: Desired initial probability.

    Returns:
        The bias value as a Python ``float``.
    """
    odds_against = (1 - prior_prob) / prior_prob
    return float(-math.log(odds_against))
|
|
|
|
|
|
def deformable_attention_core_func(
    value, value_spatial_shapes, sampling_locations, attention_weights
):
    """Multi-scale deformable attention: sample values bilinearly and mix them.

    Args:
        value (Tensor): [bs, value_length, n_head, c], all levels flattened
            and concatenated along ``value_length``.
        value_spatial_shapes (Tensor|List): [n_levels, 2], ``(h, w)`` per level.
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, query_length, n_head * c]
    """
    batch, _, heads, head_dim = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape

    # Cut the concatenated value sequence back into one chunk per level.
    level_sizes = [h * w for h, w in value_spatial_shapes]
    per_level_values = value.split(level_sizes, dim=1)

    # Rescale so that 0 -> -1 and 1 -> +1, the convention F.grid_sample uses.
    grids = 2 * sampling_locations - 1

    sampled = []
    for level, ((h, w), level_value) in enumerate(
        zip(value_spatial_shapes, per_level_values)
    ):
        # [bs, h*w, n_head, c] -> [bs * n_head, c, h, w]
        feature_map = (
            level_value.flatten(2)
            .permute(0, 2, 1)
            .reshape(batch * heads, head_dim, h, w)
        )
        # [bs, Len_q, n_head, n_points, 2] -> [bs * n_head, Len_q, n_points, 2]
        level_grid = grids[:, :, :, level].permute(0, 2, 1, 3, 4).flatten(0, 1)
        # -> [bs * n_head, c, Len_q, n_points]
        sampled.append(
            F.grid_sample(
                feature_map,
                level_grid,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
        )

    # [bs, Len_q, n_head, n_levels, n_points] -> [bs * n_head, 1, Len_q, n_levels * n_points]
    weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        batch * heads, 1, num_queries, num_levels * num_points
    )

    # Stack levels next to the point axis, then merge both into one axis so
    # the layout lines up with ``weights``.
    stacked = torch.stack(sampled, dim=-2).flatten(-2)
    aggregated = (stacked * weights).sum(-1)

    return aggregated.reshape(batch, heads * head_dim, num_queries).permute(0, 2, 1)
|
|
|
|
|
|
def deformable_attention_core_func_v2(
    value: torch.Tensor,
    value_spatial_shapes,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
    num_points_list: List[int],
    method="default",
):
    """Multi-scale deformable attention with a per-level number of points.

    Args:
        value: Indexed per level; ``value[level]`` must be reshapeable to
            [bs * n_head, c, h, w] and ``value[0]`` has shape
            [bs, n_head, c, h * w].
        value_spatial_shapes (Tensor|List): [n_levels, 2], ``(h, w)`` per level.
        sampling_locations (Tensor): [bs, query_length, n_head, sum(num_points_list), 2]
        attention_weights (Tensor): [bs, query_length, n_head, sum(num_points_list)]
        num_points_list: Number of sampling points used on each level.
        method: ``"default"`` samples bilinearly with ``F.grid_sample``;
            ``"discrete"`` rounds each location to a single pixel and gathers it.

    Returns:
        output (Tensor): [bs, query_length, n_head * c]

    Raises:
        ValueError: If ``method`` is neither ``"default"`` nor ``"discrete"``.
    """
    if method not in ("default", "discrete"):
        raise ValueError(
            f"Unsupported method {method!r}; expected 'default' or 'discrete'."
        )

    bs, n_head, c, _ = value[0].shape
    _, Len_q, _, _, _ = sampling_locations.shape

    if method == "default":
        # Rescale so that 0 -> -1 and 1 -> +1, the convention F.grid_sample uses.
        sampling_grids = 2 * sampling_locations - 1
    else:
        # Discrete lookup scales the locations by (w, h) itself, below.
        sampling_grids = sampling_locations

    # [bs, Len_q, n_head, P, 2] -> [bs * n_head, Len_q, P, 2], then one chunk per level.
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_locations_list = sampling_grids.split(num_points_list, dim=-2)

    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        value_l = value[level].reshape(bs * n_head, c, h, w)
        sampling_grid_l: torch.Tensor = sampling_locations_list[level]

        if method == "default":
            sampling_value_l = F.grid_sample(
                value_l,
                sampling_grid_l,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
        else:
            num_samples = Len_q * num_points_list[level]

            # Scale to pixel units and round to the nearest integer index.
            sampling_coord = (
                sampling_grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5
            ).to(torch.int64)

            # Clamp each axis against its own extent: x is bounded by the
            # width and y by the height. Using a single bound for both is
            # wrong on non-square maps (out-of-range or truncated x).
            x_idx = sampling_coord[..., 0].clamp(0, w - 1).reshape(
                bs * n_head, num_samples
            )
            y_idx = sampling_coord[..., 1].clamp(0, h - 1).reshape(
                bs * n_head, num_samples
            )

            # Row index for every sample so each (batch * head) slice gathers
            # from its own feature map; expand avoids materialising a copy.
            s_idx = (
                torch.arange(bs * n_head, device=value_l.device)
                .unsqueeze(-1)
                .expand_as(x_idx)
            )

            # Advanced indexing yields [bs * n_head, num_samples, c].
            sampling_value_l = value_l[s_idx, :, y_idx, x_idx]
            sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(
                bs * n_head, c, Len_q, num_points_list[level]
            )

        sampling_value_list.append(sampling_value_l)

    # [bs, Len_q, n_head, P] -> [bs * n_head, 1, Len_q, P]
    attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
        bs * n_head, 1, Len_q, sum(num_points_list)
    )
    weighted_sample_locs = torch.cat(sampling_value_list, dim=-1) * attn_weights
    output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)

    return output.permute(0, 2, 1)
|
|
|
|
|
|
def get_activation(act: str, inpace: bool = True):
    """Build an activation module from its name.

    Args:
        act: Case-insensitive activation name (``"silu"``/``"swish"``,
            ``"relu"``, ``"leaky_relu"``, ``"gelu"``, ``"hardsigmoid"``),
            ``None`` for an identity, or an existing ``nn.Module`` which is
            returned unchanged.
        inpace: Value assigned to the module's ``inplace`` attribute when it
            has one. The spelling is kept for backward compatibility with
            existing keyword callers.

    Returns:
        The activation as an ``nn.Module``.

    Raises:
        RuntimeError: If ``act`` is a name that is not supported.
    """
    if act is None:
        return nn.Identity()
    if isinstance(act, nn.Module):
        return act

    # Name -> module class; "swish" is an alias of SiLU.
    factories = {
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "gelu": nn.GELU,
        "hardsigmoid": nn.Hardsigmoid,
    }

    name = act.lower()
    factory = factories.get(name)
    if factory is None:
        supported = ", ".join(sorted(factories))
        raise RuntimeError(
            f"Unsupported activation {act!r}; expected one of: {supported}."
        )

    m = factory()

    # Not every activation exposes ``inplace`` (e.g. GELU does not).
    if hasattr(m, "inplace"):
        m.inplace = inpace

    return m
|
|
|