File size: 14,890 Bytes
89b2487 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 |
import copy
import math
from dataclasses import astuple
import torch
from torch import nn
from torch.nn.modules.transformer import _get_activation_fn
from torchvision.ops import RoIAlign
# Upper bound applied to the log-space size deltas (dw, dh) before torch.exp()
# in DiffusionDetHead.apply_deltas: one call can scale a box side by at most
# exp(_DEFAULT_SCALE_CLAMP) == 1000 / 16 == 62.5.
_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
def convert_boxes_to_pooler_format(bboxes):
    """
    Flatten batched boxes into the ``(K, 5)`` layout consumed by RoIAlign.

    Args:
        bboxes (Tensor): boxes of shape ``(bs, num_proposals, 4)``; any trailing
            dimensions after the first two are flattened into the box columns.

    Returns:
        Tensor: shape ``(bs * num_proposals, 5)``. Column 0 holds the index of the
        image each box belongs to (same dtype as ``bboxes``); the remaining
        columns are the box coordinates, in image-major order.
    """
    bs, num_proposals = bboxes.shape[:2]
    # flatten() (unlike view) accepts non-contiguous inputs, and unlike
    # reshape(..., -1) it also works when num_proposals == 0.
    aggregated_bboxes = bboxes.flatten(2).flatten(0, 1)
    # Batch index per box: num_proposals copies of 0, then of 1, ...
    indices = torch.arange(
        bs, dtype=aggregated_bboxes.dtype, device=aggregated_bboxes.device
    ).repeat_interleave(num_proposals)
    return torch.cat([indices[:, None], aggregated_bboxes], dim=1)
def assign_boxes_to_levels(
    bboxes,
    min_level,
    max_level,
    canonical_box_size,
    canonical_level,
):
    """
    Assign every box to a feature-pyramid level using Eqn. (1) of the FPN paper.

    Args:
        bboxes (Tensor): boxes of shape ``(bs, num_proposals, 4)`` in
            ``(x1, y1, x2, y2)`` order.
        min_level (int): finest available level.
        max_level (int): coarsest available level.
        canonical_box_size (int): box size (sqrt of area) mapped to ``canonical_level``.
        canonical_level (int): level assigned to boxes of ``canonical_box_size``.

    Returns:
        Tensor: int64 tensor of shape ``(bs * num_proposals,)`` holding level
        indices relative to ``min_level``, i.e. in ``[0, max_level - min_level]``.
    """
    aggregated_bboxes = bboxes.flatten(2).flatten(0, 1)
    widths = aggregated_bboxes[:, 2] - aggregated_bboxes[:, 0]
    heights = aggregated_bboxes[:, 3] - aggregated_bboxes[:, 1]
    # A degenerate box (x2 < x1 xor y2 < y1) has negative area; sqrt would give
    # NaN, and casting NaN to int64 is undefined. Treat such boxes as zero-sized,
    # which places them on the finest level.
    box_sizes = torch.sqrt(torch.clamp(widths * heights, min=0))
    # Eqn.(1) in FPN paper; the 1e-8 keeps log2 finite for zero-sized boxes.
    level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
    # clamp level to (min, max), in case the box size is too large or too small
    # for the available feature maps
    level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
    return level_assignments.to(torch.int64) - min_level
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embedding of a batch of scalar time steps.

    Args:
        dim (int): width of the produced embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Embed ``time``.

        Args:
            time (Tensor): 1-D tensor of shape ``(batch,)``.

        Returns:
            Tensor: shape ``(batch, dim)``. The first ``dim // 2`` columns are
            sines and the next ``dim // 2`` are cosines; for odd ``dim`` a trailing
            zero column is appended so the width is always exactly ``dim``.
        """
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000. The max() guard
        # avoids a division by zero when half_dim == 1 (dim in {2, 3}).
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -exponent)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Keep the output width equal to `dim` so it matches nn.Linear(dim, ...).
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
class HeadDynamicK(nn.Module):
    """
    Cascade of ``config.num_heads`` DiffusionDetHead stages sharing one time embedding.

    Each stage starts from the boxes predicted by the previous stage; the boxes are
    detached, so no gradient flows between stages through them.
    """

    def __init__(self, config, roi_input_shape):
        super().__init__()
        num_classes = config.num_labels
        self.num_head = config.num_heads
        # forward() returns the last stage's outputs, so at least one stage is needed.
        if self.num_head < 1:
            raise ValueError(f"config.num_heads must be at least 1, got {self.num_head}")
        ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
        # Structurally identical copies; _reset_parameters() re-initialises every
        # weight matrix afterwards, so the stages do not share initial weights.
        self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
        self.return_intermediate = config.deep_supervision
        # Sinusoidal time embedding followed by an MLP; output width is 4 * hidden_dim.
        self.hidden_dim = config.hidden_dim
        time_dim = self.hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.hidden_dim),
            nn.Linear(self.hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )
        # Init parameters.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.num_classes = num_classes
        if self.use_focal or self.use_fed_loss:
            prior_prob = config.prior_prob
            # sigmoid(bias_value) == prior_prob, so initial class scores start at the prior.
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()

    def _reset_parameters(self):
        """
        Xavier-initialise every parameter with more than one dimension.

        With focal or federated loss, also set each stage's classifier bias to
        ``bias_value``.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Target the classifier bias explicitly. Selecting parameters by size
        # (last dim == num_classes or num_classes + 1) also overwrites unrelated
        # tensors whenever a width coincides, e.g. the 4-wide box-delta bias when
        # there are 3 or 4 classes, or LayerNorm/Linear weights of width num_classes.
        if self.use_focal or self.use_fed_loss:
            for head in self.head_series:
                nn.init.constant_(head.class_logits.bias, self.bias_value)

    def forward(self, features, bboxes, t):
        """
        Run all stages.

        Args:
            features: feature maps, forwarded unchanged to every stage.
            bboxes (Tensor): boxes fed to the first stage.
            t (Tensor): 1-D tensor of time steps, one per image (shape ``(batch_size,)``).

        Returns:
            tuple(Tensor, Tensor): class logits and predicted boxes. With deep
            supervision, the outputs of every stage are stacked along a new leading
            dim; otherwise only the last stage's outputs are returned, with a leading
            dim of size 1.
        """
        time = self.time_mlp(t)
        inter_class_logits = []
        inter_pred_bboxes = []
        for ddet_head in self.head_series:
            class_logits, pred_bboxes, _ = ddet_head(features, bboxes, time)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            # The next stage refines these boxes without backpropagating into them.
            bboxes = pred_bboxes.detach()
        if self.return_intermediate:
            return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)
        return class_logits[None], pred_bboxes[None]
class DynamicConv(nn.Module):
    """
    Instance interaction through per-proposal dynamic parameters.

    Each proposal feature generates the weights of two linear maps
    (``hidden_dim -> dim_dynamic -> hidden_dim``). These maps are applied to that
    proposal's pooled RoI features, and the result is flattened and projected
    back to ``hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.dim_dynamic = config.dim_dynamic
        self.num_dynamic = config.num_dynamic
        # forward() splits the generated parameters into exactly two matrices;
        # any other value would only fail later with an opaque shape error.
        if self.num_dynamic != 2:
            raise ValueError(f"DynamicConv only supports num_dynamic == 2, got {self.num_dynamic}")
        self.num_params = self.hidden_dim * self.dim_dynamic
        self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)
        self.norm1 = nn.LayerNorm(self.dim_dynamic)
        self.norm2 = nn.LayerNorm(self.hidden_dim)
        self.activation = nn.ReLU(inplace=True)
        pooler_resolution = config.pooler_resolution
        num_output = self.hidden_dim * pooler_resolution ** 2
        self.out_layer = nn.Linear(num_output, self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)

    def forward(self, pro_features, roi_features):
        """
        Args:
            pro_features (Tensor): proposal features, expected shape ``(1, K, hidden_dim)``.
            roi_features (Tensor): pooled RoI features of shape
                ``(pooler_resolution ** 2, K, hidden_dim)``.

        Returns:
            Tensor: shape ``(K, hidden_dim)``, non-negative (final ReLU).
        """
        # (K, pooler_resolution ** 2, hidden_dim)
        features = roi_features.permute(1, 0, 2)
        # (K, 1, 2 * num_params): one flat parameter vector per proposal.
        parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)
        param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)
        param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)
        # Per-proposal hidden_dim -> dim_dynamic projection.
        features = torch.bmm(features, param1)
        features = self.norm1(features)
        features = self.activation(features)
        # Per-proposal dim_dynamic -> hidden_dim projection.
        features = torch.bmm(features, param2)
        features = self.norm2(features)
        features = self.activation(features)
        # Flatten all pooled positions and project to hidden_dim.
        features = features.flatten(1)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)
        return features
class DiffusionDetHead(nn.Module):
    """
    One box-refinement stage.

    Pools RoI features for the current boxes and mixes the proposal features
    with self-attention across proposals and a DynamicConv over each box's
    pooled features. It then modulates the result with the time embedding and
    predicts class logits and refined boxes.
    """
    def __init__(self, config, roi_input_shape, num_classes):
        """
        Args:
            config: model config. This constructor reads dim_feedforward,
                num_attn_heads, dropout, activation, roi_head_in_features,
                pooler_resolution, sampling_ratio, hidden_dim, num_cls, num_reg,
                use_focal and use_fed_loss; DynamicConv(config) reads more fields.
            roi_input_shape: mapping from feature name to a dict with a 'stride'
                entry; only the names listed in config.roi_head_in_features are read.
            num_classes (int): number of classes. Without focal/federated loss the
                classifier gets one extra output (presumably background; confirm
                against the loss).
        """
        super().__init__()
        dim_feedforward = config.dim_feedforward
        nhead = config.num_attn_heads
        dropout = config.dropout
        activation = config.activation
        in_features = config.roi_head_in_features
        pooler_resolution = config.pooler_resolution
        # RoIAlign spatial scale of each feature map is 1 / stride.
        pooler_scales = tuple(1.0 / roi_input_shape[k]['stride'] for k in in_features)
        sampling_ratio = config.sampling_ratio
        self.hidden_dim = config.hidden_dim
        self.pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
        )
        # Self-attention across proposals, then dynamic instance interaction.
        self.self_attn = nn.MultiheadAttention(self.hidden_dim, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(config)
        # Feed-forward block: hidden_dim -> dim_feedforward -> hidden_dim.
        self.linear1 = nn.Linear(self.hidden_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, self.hidden_dim)
        self.norm1 = nn.LayerNorm(self.hidden_dim)
        self.norm2 = nn.LayerNorm(self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        # Maps the (4 * hidden_dim)-wide time embedding to a per-channel
        # (scale, shift) pair of total width 2 * hidden_dim.
        self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(self.hidden_dim * 4, self.hidden_dim * 2))
        # Classification tower: num_cls x (bias-free Linear -> LayerNorm -> ReLU).
        num_cls = config.num_cls
        cls_module = list()
        for _ in range(num_cls):
            cls_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            cls_module.append(nn.LayerNorm(self.hidden_dim))
            cls_module.append(nn.ReLU(inplace=True))
        self.cls_module = nn.ModuleList(cls_module)
        # Regression tower: num_reg x (bias-free Linear -> LayerNorm -> ReLU).
        num_reg = config.num_reg
        reg_module = list()
        for _ in range(num_reg):
            reg_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            reg_module.append(nn.LayerNorm(self.hidden_dim))
            reg_module.append(nn.ReLU(inplace=True))
        self.reg_module = nn.ModuleList(reg_module)
        # Prediction layers.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        if self.use_focal or self.use_fed_loss:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes)
        else:
            self.class_logits = nn.Linear(self.hidden_dim, num_classes + 1)
        # Predicts (dx, dy, dw, dh) per box; see apply_deltas.
        self.bboxes_delta = nn.Linear(self.hidden_dim, 4)
        # Upper bound for dw/dh before torch.exp() in apply_deltas.
        self.scale_clamp = _DEFAULT_SCALE_CLAMP
        # Divisors applied to (dx, dy, dw, dh) in apply_deltas.
        self.bbox_weights = (2.0, 2.0, 1.0, 1.0)
    def forward(self, features, bboxes, time_emb):
        """
        Run one refinement stage.

        Args:
            features: feature maps, passed to self.pooler.
            bboxes (Tensor): current boxes, shape (bs, num_proposals, 4).
            time_emb (Tensor): time embedding, shape (bs, 4 * hidden_dim).

        Returns:
            tuple: class_logits of shape (bs, num_proposals, num_outputs),
            pred_bboxes of shape (bs, num_proposals, 4), and obj_features of
            shape (1, bs * num_proposals, hidden_dim).
        """
        bs, num_proposals = bboxes.shape[:2]
        # Pooled RoI features, one entry per box in image-major order; the views
        # below require the channel count to equal hidden_dim.
        roi_features = self.pooler(features, bboxes)
        # Initial proposal features: mean over pooled positions -> (bs, num_proposals, hidden_dim).
        pro_features = roi_features.view(bs, num_proposals, self.hidden_dim, -1).mean(-1)
        # (pooled positions, bs * num_proposals, hidden_dim), the layout DynamicConv expects.
        roi_features = roi_features.view(bs * num_proposals, self.hidden_dim, -1).permute(2, 0, 1)
        # (num_proposals, bs, hidden_dim): nn.MultiheadAttention's default
        # (seq, batch, embed) layout, so attention runs across the proposals of one image.
        pro_features = pro_features.view(bs, num_proposals, self.hidden_dim).permute(1, 0, 2)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)
        # Back to image-major order, flattened to (1, bs * num_proposals, hidden_dim).
        pro_features = pro_features.view(num_proposals, bs, self.hidden_dim).permute(1, 0, 2).reshape(1, bs * num_proposals,
                                                                                                       self.hidden_dim)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)
        # Feed-forward block with residual connection and LayerNorm.
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)
        # One row per box: (bs * num_proposals, hidden_dim).
        fc_feature = obj_features.transpose(0, 1).reshape(bs * num_proposals, -1)
        # Time conditioning: one (scale, shift) pair per image, repeated for each
        # of its proposals, applied as x * (scale + 1) + shift.
        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_proposals, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift
        # Separate copies feed the classification and regression towers.
        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        # Refine the input boxes with the predicted (dx, dy, dw, dh) deltas.
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))
        return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features
    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4), as (x1, y1, x2, y2).

        Returns:
            Tensor: transformed boxes with the same shape and dtype as `deltas`,
            each group of four columns in (x1, y1, x2, y2) order.
        """
        boxes = boxes.to(deltas.dtype)
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights
        # Each delta component is divided by its weight before use.
        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh
        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)
        # Centers shift by (dx * width, dy * height); sizes scale by exp(dw), exp(dh).
        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]
        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
        return pred_boxes
class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.

    With a single feature map every box is pooled from it. With several, each box
    is routed to exactly one level by ``assign_boxes_to_levels`` (FPN Eqn. 1).
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        canonical_box_size=224,
        canonical_level=4,
    ):
        """
        Args:
            output_size (int or tuple[int, int]): pooled ``(height, width)``; an int
                gives a square output.
            scales (sequence[float]): ``spatial_scale`` of each feature map, from the
                largest scale (finest map) to the smallest. ``scales[0]`` and
                ``scales[-1]`` must be powers of 1/2, and ``len(scales)`` must equal
                the number of levels between them (inclusive).
            sampling_ratio (int): forwarded to ``torchvision.ops.RoIAlign``.
            canonical_box_size (int): box size (sqrt of area) mapped to ``canonical_level``.
            canonical_level (int): level assigned to boxes of ``canonical_box_size``.
        """
        super().__init__()
        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int), (
            f"output_size must be an int or a pair of ints, got {output_size!r}"
        )
        # round() rather than int(): int() truncates, so a level computed as
        # 2.9999999 would wrongly fail this check (or be stored as 2).
        assert math.isclose(min_level, round(min_level)) and math.isclose(max_level, round(max_level)), (
            f"scales[0] and scales[-1] must be powers of 1/2, got {scales[0]} and {scales[-1]}"
        )
        min_level = round(min_level)
        max_level = round(max_level)
        assert len(scales) == max_level - min_level + 1, (
            f"expected {max_level - min_level + 1} scales for levels {min_level}..{max_level}, got {len(scales)}"
        )
        assert 0 <= min_level <= max_level, (
            f"scales must be <= 1 and ordered finest first, got levels {min_level}..{max_level}"
        )
        assert canonical_box_size > 0, f"canonical_box_size must be positive, got {canonical_box_size}"
        self.output_size = output_size
        self.min_level = min_level
        self.max_level = max_level
        self.canonical_level = canonical_level
        self.canonical_box_size = canonical_box_size
        self.level_poolers = nn.ModuleList(
            RoIAlign(
                output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
            )
            for scale in scales
        )

    def forward(self, x, bboxes):
        """
        Pool one fixed-size feature patch per box.

        Args:
            x (list[Tensor]): one feature map per level, each of shape (bs, C, H_l, W_l).
            bboxes (Tensor): boxes of shape (bs, num_proposals, 4).

        Returns:
            Tensor: shape (bs * num_proposals, C, out_height, out_width), one row per
            box in the same order as the flattened ``bboxes``.
        """
        num_level_assignments = len(self.level_poolers)
        assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0), (
            f"expected {num_level_assignments} feature maps whose batch size matches "
            f"the {len(bboxes)} box sets, got {len(x)} maps"
        )
        pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)
        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)
        level_assignments = assign_boxes_to_levels(
            bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )
        num_boxes = pooler_fmt_boxes.shape[0]
        num_channels = x[0].shape[1]
        # Use both dimensions: output_size may be non-square, and RoIAlign
        # returns (height, width) patches.
        output_height, output_width = self.output_size
        output = torch.zeros(
            (num_boxes, num_channels, output_height, output_width), dtype=x[0].dtype, device=x[0].device
        )
        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            inds = (level_assignments == level).nonzero(as_tuple=True)[0]
            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
            # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
            output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes_level))
        return output
|