|
import copy |
|
import math |
|
from dataclasses import astuple |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn.modules.transformer import _get_activation_fn |
|
from torchvision.ops import RoIAlign |
|
|
|
# Upper bound for the log-scale size deltas (dw, dh) before they are
# exponentiated when applying box deltas: a box side can grow by at most
# exp(_DEFAULT_SCALE_CLAMP) = 1000 / 16 = 62.5x in a single application.
_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16.0)
|
|
|
def convert_boxes_to_pooler_format(bboxes):
    """
    Convert batched boxes into the flat ``(K, 5)`` layout consumed by ``RoIAlign``.

    Args:
        bboxes (Tensor): boxes of shape ``(batch_size, num_proposals, 4)``. All
            dimensions after the first two are flattened into the coordinate
            columns as-is.

    Returns:
        Tensor of shape ``(batch_size * num_proposals, 1 + coords)`` whose first
        column is the batch index of each box (same dtype/device as ``bboxes``),
        followed by the box coordinates, ordered batch-major.
    """
    bs, num_proposals = bboxes.shape[:2]
    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted as well.
    aggregated_bboxes = bboxes.reshape(bs * num_proposals, -1)
    # Repeat each batch index ``num_proposals`` times. Passing a Python int as the
    # repeat count avoids building an auxiliary ``sizes`` tensor on the host,
    # copying it to the device and using it as a tensor-valued ``repeats``.
    indices = torch.arange(
        bs, dtype=aggregated_bboxes.dtype, device=aggregated_bboxes.device
    ).repeat_interleave(num_proposals)
    return torch.cat([indices[:, None], aggregated_bboxes], dim=1)
|
|
|
|
|
def assign_boxes_to_levels(
    bboxes,
    min_level,
    max_level,
    canonical_box_size,
    canonical_level,
):
    """
    Map each box to the index of the feature level it should be pooled from.

    The level is ``floor(canonical_level + log2(sqrt(area) / canonical_box_size))``
    clamped to ``[min_level, max_level]``.

    Args:
        bboxes (Tensor): boxes of shape ``(batch_size, num_proposals, 4)`` in
            ``(x1, y1, x2, y2)`` format.
        min_level (int): lowest allowed level (inclusive).
        max_level (int): highest allowed level (inclusive).
        canonical_box_size (float): box side length that maps to ``canonical_level``.
        canonical_level (int): level assigned to a box of side ``canonical_box_size``.

    Returns:
        int64 Tensor of shape ``(batch_size * num_proposals,)`` with level indices
        relative to ``min_level``, i.e. in ``[0, max_level - min_level]``.
    """
    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted as well.
    aggregated_bboxes = bboxes.reshape(bboxes.shape[0] * bboxes.shape[1], -1)
    widths = aggregated_bboxes[:, 2] - aggregated_bboxes[:, 0]
    heights = aggregated_bboxes[:, 3] - aggregated_bboxes[:, 1]
    # A box inverted along exactly one axis has negative area; sqrt would give NaN
    # and casting NaN to int64 is implementation-defined (in practice it matches
    # no level, so the box would silently receive no pooled features). Treat such
    # boxes as zero-area, which sends them to ``min_level``.
    area = (widths * heights).clamp(min=0)
    box_sizes = torch.sqrt(area)

    # The 1e-8 epsilon keeps log2 finite for zero-area boxes.
    level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
    return level_assignments.to(torch.int64) - min_level
|
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embedding of scalar positions / timesteps.

    Each scalar ``t`` is mapped to
    ``[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]``
    with ``h = dim // 2`` geometrically spaced frequencies
    ``f_i = 10000 ** (-i / max(h - 1, 1))``.

    Args:
        dim (int): size of the output embedding. For odd ``dim`` one zero column
            is appended so the output always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time (Tensor): 1-D tensor of positions/timesteps, shape ``(batch_size,)``.

        Returns:
            Tensor of shape ``(batch_size, dim)``.
        """
        device = time.device
        half_dim = self.dim // 2
        # ``max(..., 1)`` avoids a ZeroDivisionError when dim is 2 or 3, where there
        # is a single frequency f_0 = 1.
        embeddings = math.log(10000) / max(half_dim - 1, 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) columns; pad so consumers sized by
            # ``dim`` (e.g. ``nn.Linear(dim, ...)``) receive the expected width.
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
|
|
|
|
|
class HeadDynamicK(nn.Module):
    """
    Stack of ``config.num_heads`` ``DiffusionDetHead`` stages applied in sequence.

    Each stage refines the boxes produced by the previous stage; gradients are not
    propagated through the boxes handed from one stage to the next. All stages
    share one time embedding, computed once per forward pass by ``time_mlp``.
    """

    def __init__(self, config, roi_input_shape):
        super().__init__()

        num_classes = config.num_labels

        # Build one head and deep-copy it so every stage has the same architecture;
        # all weights are re-initialized afterwards by ``_reset_parameters``.
        ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
        self.num_head = config.num_heads
        self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
        # When True, forward returns the predictions of every stage, not just the last.
        self.return_intermediate = config.deep_supervision

        # Time embedding: sinusoidal features of size hidden_dim projected to
        # 4 * hidden_dim (the input width of each head's ``block_time_mlp``).
        self.hidden_dim = config.hidden_dim
        time_dim = self.hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.hidden_dim),
            nn.Linear(self.hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.num_classes = num_classes
        if self.use_focal or self.use_fed_loss:
            # Bias value b such that sigmoid(b) == prior_prob.
            prior_prob = config.prior_prob
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()

    def _reset_parameters(self):
        """
        Xavier-initialize every parameter with more than one dimension and, when
        focal or federated loss is enabled, set each head's classification bias
        to ``self.bias_value``.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        if self.use_focal or self.use_fed_loss:
            # Target the classifier bias directly rather than selecting parameters
            # by trailing size: a size test would also overwrite any unrelated
            # tensor whose last dimension happens to equal num_classes (or
            # num_classes + 1), e.g. a LayerNorm of width hidden_dim or dim_dynamic.
            for head in self.head_series:
                nn.init.constant_(head.class_logits.bias, self.bias_value)

    def forward(self, features, bboxes, t):
        """
        Run all head stages.

        Args:
            features: multi-level feature maps, passed unchanged to every head.
            bboxes (Tensor): initial boxes, ``(batch_size, num_proposals, 4)``.
            t (Tensor): 1-D tensor of timesteps, one per image.

        Returns:
            ``(class_logits, pred_bboxes)``, each stacked along a new leading stage
            dimension: of size ``num_heads`` when ``return_intermediate`` is set,
            otherwise of size 1 holding only the last stage's output.
        """
        time = self.time_mlp(t)

        inter_class_logits = []
        inter_pred_bboxes = []
        class_logits, pred_bboxes = None, None
        for ddet_head in self.head_series:
            class_logits, pred_bboxes, _ = ddet_head(features, bboxes, time)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            # The next stage refines these boxes without backpropagating into them.
            bboxes = pred_bboxes.detach()

        if self.return_intermediate:
            return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)
        return class_logits[None], pred_bboxes[None]
|
|
|
|
|
class DynamicConv(nn.Module):
    """
    Dynamic instance interaction between proposal features and RoI features.

    ``dynamic_layer`` generates, per proposal, a ``(hidden_dim, dim_dynamic)``
    matrix and a ``(dim_dynamic, hidden_dim)`` matrix. These are applied in turn
    to that proposal's RoI features, each followed by LayerNorm + ReLU. The result
    is flattened and projected back to ``hidden_dim`` (again LayerNorm + ReLU).
    """

    def __init__(self, config):
        super().__init__()

        hidden_dim = config.hidden_dim
        dim_dynamic = config.dim_dynamic
        self.hidden_dim = hidden_dim
        self.dim_dynamic = dim_dynamic
        self.num_dynamic = config.num_dynamic
        # Number of entries in one generated (hidden_dim x dim_dynamic) matrix.
        self.num_params = hidden_dim * dim_dynamic

        # NOTE: keep this creation order; the Linear layers draw from the global
        # RNG when they are constructed.
        self.dynamic_layer = nn.Linear(hidden_dim, self.num_dynamic * self.num_params)
        self.norm1 = nn.LayerNorm(dim_dynamic)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.activation = nn.ReLU(inplace=True)

        flattened_size = hidden_dim * config.pooler_resolution ** 2
        self.out_layer = nn.Linear(flattened_size, hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, pro_features, roi_features):
        """
        Args:
            pro_features (Tensor): proposal features, ``(1, N, hidden_dim)``.
            roi_features (Tensor): RoI features, ``(S, N, hidden_dim)``, where
                ``S * hidden_dim`` must equal ``out_layer``'s input size.

        Returns:
            Tensor of shape ``(N, hidden_dim)``.
        """
        split_at = self.num_params
        # (N, 1, num_dynamic * num_params): generated weights, one row per proposal.
        generated = self.dynamic_layer(pro_features).permute(1, 0, 2)
        w_in = generated[:, :, :split_at].view(-1, self.hidden_dim, self.dim_dynamic)
        w_out = generated[:, :, split_at:].view(-1, self.dim_dynamic, self.hidden_dim)

        # (N, S, hidden_dim) -> (N, S, dim_dynamic) -> (N, S, hidden_dim)
        x = roi_features.permute(1, 0, 2)
        x = self.activation(self.norm1(torch.bmm(x, w_in)))
        x = self.activation(self.norm2(torch.bmm(x, w_out)))

        # (N, S * hidden_dim) -> (N, hidden_dim)
        return self.activation(self.norm3(self.out_layer(x.flatten(1))))
|
|
|
|
|
class DiffusionDetHead(nn.Module):
    """
    One refinement stage of the DiffusionDet head.

    Given multi-level features, the current boxes and a time embedding, it pools
    RoI features for every box, mixes proposals with self-attention and a
    ``DynamicConv`` interaction, modulates the result with a time-conditioned
    scale/shift, and predicts class logits plus refined boxes.
    """

    def __init__(self, config, roi_input_shape, num_classes):
        super().__init__()

        dim_feedforward = config.dim_feedforward
        nhead = config.num_attn_heads
        dropout = config.dropout
        activation = config.activation
        in_features = config.roi_head_in_features
        pooler_resolution = config.pooler_resolution
        # RoIAlign expects 1 / stride of each input feature map as its spatial scale.
        pooler_scales = tuple(1.0 / roi_input_shape[k]['stride'] for k in in_features)
        sampling_ratio = config.sampling_ratio

        self.hidden_dim = config.hidden_dim

        self.pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
        )

        # NOTE: keep the submodule creation order below; layers with parameters
        # draw from the global RNG when they are constructed.

        # Proposal self-attention and instance interaction.
        self.self_attn = nn.MultiheadAttention(self.hidden_dim, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(config)

        # Feed-forward block.
        self.linear1 = nn.Linear(self.hidden_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, self.hidden_dim)

        self.norm1 = nn.LayerNorm(self.hidden_dim)
        self.norm2 = nn.LayerNorm(self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # Maps the (4 * hidden_dim) time embedding to a per-image (scale, shift) pair.
        self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(self.hidden_dim * 4, self.hidden_dim * 2))

        # Classification and regression towers (classification built first).
        self.cls_module = nn.ModuleList(self._make_tower(config.num_cls))
        self.reg_module = nn.ModuleList(self._make_tower(config.num_reg))

        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        # Without focal/fed loss the classifier has one extra output
        # (num_classes + 1; presumably a background class — confirm with the loss).
        if self.use_focal or self.use_fed_loss:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes)
        else:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes + 1)
        self.bboxes_delta = nn.Linear(self.hidden_dim, 4)
        self.scale_clamp = _DEFAULT_SCALE_CLAMP
        # Divisors (wx, wy, ww, wh) applied to the predicted (dx, dy, dw, dh).
        self.bbox_weights = (2.0, 2.0, 1.0, 1.0)

    def _make_tower(self, num_layers):
        """Return ``num_layers`` x [Linear(no bias), LayerNorm, ReLU] as a flat list."""
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            layers.append(nn.LayerNorm(self.hidden_dim))
            layers.append(nn.ReLU(inplace=True))
        return layers

    def forward(self, features, bboxes, time_emb):
        """
        Args:
            features: per-level feature maps forwarded to ``self.pooler``.
            bboxes (Tensor): current boxes, ``(batch_size, num_proposals, 4)``.
            time_emb (Tensor): time embedding, ``(batch_size, 4 * hidden_dim)``.

        Returns:
            class_logits: ``(batch_size, num_proposals, C)``.
            pred_bboxes: refined boxes, ``(batch_size, num_proposals, 4)``.
            obj_features: ``(1, batch_size * num_proposals, hidden_dim)``.
        """
        bs, num_proposals = bboxes.shape[:2]

        roi_features = self.pooler(features, bboxes)
        # Initial proposal feature: spatial mean of its pooled RoI features.
        pro_features = roi_features.view(bs, num_proposals, self.hidden_dim, -1).mean(-1)
        # (S, bs * num_proposals, hidden_dim) with S spatial positions, for DynamicConv.
        roi_features = roi_features.view(bs * num_proposals, self.hidden_dim, -1).permute(2, 0, 1)

        # Self-attention over the proposals of each image: (num_proposals, bs, hidden_dim).
        pro_features = pro_features.view(bs, num_proposals, self.hidden_dim).permute(1, 0, 2)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # Instance interaction in batch-major order: (1, bs * num_proposals, hidden_dim).
        pro_features = pro_features.permute(1, 0, 2).reshape(1, bs * num_proposals, self.hidden_dim)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # Feed-forward block.
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(0, 1).reshape(bs * num_proposals, -1)

        # Per-image time conditioning, repeated for each of that image's proposals.
        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_proposals, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        # Each tower starts with an out-of-place Linear (an empty tower applies
        # nothing), so fc_feature is never mutated and can be shared uncloned.
        cls_feature = fc_feature
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        reg_feature = fc_feature
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)

        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.reshape(-1, 4))

        return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation ``deltas`` (dx, dy, dw, dh) to ``boxes``.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific box
                transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4), as (x1, y1, x2, y2).

        Returns:
            Tensor of shape (N, k*4) holding the transformed (x1, y1, x2, y2)
            boxes, in the dtype of ``deltas``.
        """
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Cap the log-scale deltas so torch.exp() cannot blow up.
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h
        # (N, k, 4) -> (N, k*4): the interleaved x1, y1, x2, y2 layout per class,
        # built out-of-place instead of zero-filling and writing strided slices.
        return torch.stack((x1, y1, x2, y2), dim=-1).reshape(deltas.shape)
|
|
|
|
|
class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.

    With a single feature map every box is pooled from it; with several, each box
    is pooled from the one level selected by ``assign_boxes_to_levels``.
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        canonical_box_size=224,
        canonical_level=4,
    ):
        """
        Args:
            output_size (int or tuple[int, int]): pooled output size (h, w); an int
                means a square output.
            scales (sequence[float]): spatial scale (1 / stride) of each input
                feature map, ordered so that scales[0] >= scales[-1]. The first and
                last scales must be powers of 1/2 (up to float error) and
                ``len(scales)`` must equal the number of levels between them.
            sampling_ratio (int): forwarded to ``RoIAlign``.
            canonical_box_size (float): box side length mapped to ``canonical_level``.
            canonical_level (int): level assigned to a box of ``canonical_box_size``.
        """
        super().__init__()

        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))

        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
        # round() rather than int(): int() truncates, so a level that is integral
        # only up to float error (e.g. 2.9999999999) would be wrongly rejected.
        assert math.isclose(min_level, round(min_level)) and math.isclose(max_level, round(max_level))
        min_level = round(min_level)
        max_level = round(max_level)
        assert len(scales) == max_level - min_level + 1
        assert 0 <= min_level <= max_level
        assert canonical_box_size > 0

        self.output_size = output_size
        self.min_level = min_level
        self.max_level = max_level
        self.canonical_level = canonical_level
        self.canonical_box_size = canonical_box_size
        self.level_poolers = nn.ModuleList(
            RoIAlign(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True)
            for scale in scales
        )

    def forward(self, x, bboxes):
        """
        Args:
            x (list[Tensor]): one feature map per level; ``x[l].shape[0]`` is the
                batch size and ``x[l].shape[1]`` the channel count.
            bboxes (Tensor): boxes of shape ``(batch_size, num_proposals, 4)``.

        Returns:
            Tensor of shape ``(batch_size * num_proposals, C, out_h, out_w)``,
            ordered batch-major.
        """
        num_level_assignments = len(self.level_poolers)
        assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0)

        pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)

        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        level_assignments = assign_boxes_to_levels(
            bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )

        num_boxes = pooler_fmt_boxes.shape[0]
        num_channels = x[0].shape[1]
        # Allocate with both output dimensions; using output_size[0] twice breaks
        # non-square pooling when RoIAlign's (h, w) results are written below.
        output = torch.zeros(
            (num_boxes, num_channels, *self.output_size), dtype=x[0].dtype, device=x[0].device
        )

        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            inds = (level_assignments == level).nonzero(as_tuple=True)[0]
            output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes[inds]))

        return output
|
|