# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/detector.py # noqa
# Modified from https://github.com/ShoufaChen/DiffusionDet/blob/main/diffusiondet/head.py # noqa
# This work is licensed under the CC-BY-NC 4.0 License.
# Users should be careful about adopting these features in any commercial matters. # noqa
# For more details, please refer to https://github.com/ShoufaChen/DiffusionDet/blob/main/LICENSE # noqa
import copy
import math
import random
import warnings
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.ops import batched_nms
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import (bbox2roi, bbox_cxcywh_to_xyxy,
bbox_xyxy_to_cxcywh, get_box_wh,
scale_boxes)
from mmdet.utils import InstanceList
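# Maximum value allowed for the predicted dw/dh box deltas before they are
# passed to exp() when decoding boxes
# (see `SingleDiffusionDetHead.apply_deltas`).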
_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)
def cosine_beta_schedule(timesteps, s=0.008):
"""Cosine schedule as proposed in
https://openreview.net/forum?id=-NEXDKk8gZ."""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(
((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
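# Illustrative sketch (not executed anywhere in this module): for the default
# 1000-step setting the schedule yields a positive beta sequence whose
# cumulative alphas decay towards zero, which is what `_build_diffusion`
# consumes below.
#   betas = cosine_beta_schedule(1000)                 # float64, shape (1000,)
#   alphas_cumprod = torch.cumprod(1. - betas, dim=0)  # ~1.0 -> ~0.0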
def extract(a, t, x_shape):
"""extract the appropriate t index for a batch of indices."""
batch_size = t.shape[0]
out = a.gather(-1, t)
return out.reshape(batch_size, *((1, ) * (len(x_shape) - 1)))
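# A minimal sketch of the gather-and-broadcast pattern above: given any 1-D
# schedule tensor of length `timesteps` (e.g. the `sqrt_alphas_cumprod` buffer
# built in `_build_diffusion`), one coefficient is picked per sample and
# reshaped so it broadcasts over the remaining dims of `x`:
#   t = torch.tensor([10, 500])            # (batch_size, )
#   coef = extract(schedule, t, x.shape)   # x: (2, N, 4) -> coef: (2, 1, 1)
#   x_scaled = coef * x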
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(
torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
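# Shape sketch (illustrative only): a batch of integer timesteps is mapped to
# a fixed sinusoidal code of width `dim`, which `DynamicDiffusionDetHead`
# later projects with its `time_mlp`:
#   emb = SinusoidalPositionEmbeddings(256)(torch.tensor([0, 999]))  # (2, 256)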
@MODELS.register_module()
class DynamicDiffusionDetHead(nn.Module):
def __init__(self,
num_classes=80,
feat_channels=256,
num_proposals=500,
num_heads=6,
prior_prob=0.01,
snr_scale=2.0,
timesteps=1000,
sampling_timesteps=1,
self_condition=False,
box_renewal=True,
use_ensemble=True,
deep_supervision=True,
ddim_sampling_eta=1.0,
criterion=dict(
type='DiffusionDetCriterion',
num_classes=80,
assigner=dict(
type='DiffusionDetMatcher',
match_costs=[
dict(
type='FocalLossCost',
alpha=2.0,
gamma=0.25,
weight=2.0),
dict(
type='BBoxL1Cost',
weight=5.0,
box_format='xyxy'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
],
center_radius=2.5,
candidate_topk=5),
),
single_head=dict(
type='DiffusionDetHead',
num_cls_convs=1,
num_reg_convs=3,
dim_feedforward=2048,
num_heads=8,
dropout=0.0,
act_cfg=dict(type='ReLU'),
dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)),
roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(
type='RoIAlign', output_size=7, sampling_ratio=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
test_cfg=None,
**kwargs) -> None:
super().__init__()
self.roi_extractor = MODELS.build(roi_extractor)
        self.num_classes = num_classes
self.feat_channels = feat_channels
self.num_proposals = num_proposals
self.num_heads = num_heads
# Build Diffusion
assert isinstance(timesteps, int), 'The type of `timesteps` should ' \
f'be int but got {type(timesteps)}'
assert sampling_timesteps <= timesteps
self.timesteps = timesteps
self.sampling_timesteps = sampling_timesteps
self.snr_scale = snr_scale
self.ddim_sampling = self.sampling_timesteps < self.timesteps
self.ddim_sampling_eta = ddim_sampling_eta
self.self_condition = self_condition
self.box_renewal = box_renewal
self.use_ensemble = use_ensemble
self._build_diffusion()
# Build assigner
assert criterion.get('assigner', None) is not None
assigner = TASK_UTILS.build(criterion.get('assigner'))
# Init parameters.
self.use_focal_loss = assigner.use_focal_loss
self.use_fed_loss = assigner.use_fed_loss
# build criterion
criterion.update(deep_supervision=deep_supervision)
self.criterion = TASK_UTILS.build(criterion)
# Build Dynamic Head.
single_head_ = single_head.copy()
single_head_num_classes = single_head_.get('num_classes', None)
if single_head_num_classes is None:
single_head_.update(num_classes=num_classes)
else:
if single_head_num_classes != num_classes:
warnings.warn(
'The `num_classes` of `DynamicDiffusionDetHead` and '
'`SingleDiffusionDetHead` should be same, changing '
f'`single_head.num_classes` to {num_classes}')
single_head_.update(num_classes=num_classes)
single_head_feat_channels = single_head_.get('feat_channels', None)
if single_head_feat_channels is None:
single_head_.update(feat_channels=feat_channels)
else:
if single_head_feat_channels != feat_channels:
warnings.warn(
'The `feat_channels` of `DynamicDiffusionDetHead` and '
'`SingleDiffusionDetHead` should be same, changing '
f'`single_head.feat_channels` to {feat_channels}')
single_head_.update(feat_channels=feat_channels)
default_pooler_resolution = roi_extractor['roi_layer'].get(
'output_size')
assert default_pooler_resolution is not None
single_head_pooler_resolution = single_head_.get('pooler_resolution')
if single_head_pooler_resolution is None:
single_head_.update(pooler_resolution=default_pooler_resolution)
else:
if single_head_pooler_resolution != default_pooler_resolution:
warnings.warn(
'The `pooler_resolution` of `DynamicDiffusionDetHead` '
'and `SingleDiffusionDetHead` should be same, changing '
                    f'`single_head.pooler_resolution` to '
                    f'{default_pooler_resolution}')
single_head_.update(
pooler_resolution=default_pooler_resolution)
single_head_.update(
use_focal_loss=self.use_focal_loss, use_fed_loss=self.use_fed_loss)
single_head_module = MODELS.build(single_head_)
self.head_series = nn.ModuleList(
[copy.deepcopy(single_head_module) for _ in range(num_heads)])
self.deep_supervision = deep_supervision
# Gaussian random feature embedding layer for time
time_dim = feat_channels * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(feat_channels),
nn.Linear(feat_channels, time_dim), nn.GELU(),
nn.Linear(time_dim, time_dim))
self.prior_prob = prior_prob
self.test_cfg = test_cfg
self.use_nms = self.test_cfg.get('use_nms', True)
self._init_weights()
def _init_weights(self):
# init all parameters.
bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
# initialize the bias for focal loss and fed loss.
if self.use_focal_loss or self.use_fed_loss:
if p.shape[-1] == self.num_classes or \
p.shape[-1] == self.num_classes + 1:
nn.init.constant_(p, bias_value)
def _build_diffusion(self):
betas = cosine_beta_schedule(self.timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod',
torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod',
torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
# equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
# log calculation clipped because the posterior variance is 0 at
# the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped',
torch.log(posterior_variance.clamp(min=1e-20)))
self.register_buffer(
'posterior_mean_coef1',
betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
(1. - alphas_cumprod))
def forward(self, features, init_bboxes, init_t, init_features=None):
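        """Iteratively refine the input boxes with the stacked dynamic heads.

        Shape sketch (inferred from the code below, not a strict contract):
        `features` is a tuple of multi-level feature maps, `init_bboxes` is
        (bs, num_proposals, 4) in absolute xyxy format, and `init_t` is
        (bs, ). With deep supervision, the returned class logits and boxes
        are stacked over the `num_heads` refinement stages.
        """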
        time = self.time_mlp(init_t)
inter_class_logits = []
inter_pred_bboxes = []
bs = len(features[0])
bboxes = init_bboxes
if init_features is not None:
init_features = init_features[None].repeat(1, bs, 1)
proposal_features = init_features.clone()
else:
proposal_features = None
for head_idx, single_head in enumerate(self.head_series):
class_logits, pred_bboxes, proposal_features = single_head(
features, bboxes, proposal_features, self.roi_extractor, time)
if self.deep_supervision:
inter_class_logits.append(class_logits)
inter_pred_bboxes.append(pred_bboxes)
bboxes = pred_bboxes.detach()
if self.deep_supervision:
return torch.stack(inter_class_logits), torch.stack(
inter_pred_bboxes)
else:
return class_logits[None, ...], pred_bboxes[None, ...]
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
prepare_outputs = self.prepare_training_targets(batch_data_samples)
(batch_gt_instances, batch_pred_instances, batch_gt_instances_ignore,
batch_img_metas) = prepare_outputs
batch_diff_bboxes = torch.stack([
pred_instances.diff_bboxes_abs
for pred_instances in batch_pred_instances
])
batch_time = torch.stack(
[pred_instances.time for pred_instances in batch_pred_instances])
pred_logits, pred_bboxes = self(x, batch_diff_bboxes, batch_time)
output = {
'pred_logits': pred_logits[-1],
'pred_boxes': pred_bboxes[-1]
}
if self.deep_supervision:
output['aux_outputs'] = [{
'pred_logits': a,
'pred_boxes': b
} for a, b in zip(pred_logits[:-1], pred_bboxes[:-1])]
losses = self.criterion(output, batch_gt_instances, batch_img_metas)
return losses
def prepare_training_targets(self, batch_data_samples):
# hard-setting seed to keep results same (if necessary)
# random.seed(0)
# torch.manual_seed(0)
# torch.cuda.manual_seed_all(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
batch_gt_instances = []
batch_pred_instances = []
batch_gt_instances_ignore = []
batch_img_metas = []
for data_sample in batch_data_samples:
img_meta = data_sample.metainfo
gt_instances = data_sample.gt_instances
gt_bboxes = gt_instances.bboxes
h, w = img_meta['img_shape']
image_size = gt_bboxes.new_tensor([w, h, w, h])
norm_gt_bboxes = gt_bboxes / image_size
norm_gt_bboxes_cxcywh = bbox_xyxy_to_cxcywh(norm_gt_bboxes)
pred_instances = self.prepare_diffusion(norm_gt_bboxes_cxcywh,
image_size)
gt_instances.set_metainfo(dict(image_size=image_size))
gt_instances.norm_bboxes_cxcywh = norm_gt_bboxes_cxcywh
batch_gt_instances.append(gt_instances)
batch_pred_instances.append(pred_instances)
batch_img_metas.append(data_sample.metainfo)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
batch_gt_instances_ignore.append(None)
return (batch_gt_instances, batch_pred_instances,
batch_gt_instances_ignore, batch_img_metas)
def prepare_diffusion(self, gt_boxes, image_size):
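        """Generate noisy training proposals for a single image.

        Roughly (see the body for the exact steps): normalized cxcywh gt
        boxes are padded (or randomly sub-sampled) to `num_proposals`,
        rescaled to the signal range [-snr_scale, snr_scale], corrupted by
        `q_sample` at a random timestep, and converted back to absolute xyxy
        boxes stored in the returned ``InstanceData``.
        """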
device = gt_boxes.device
time = torch.randint(
0, self.timesteps, (1, ), dtype=torch.long, device=device)
noise = torch.randn(self.num_proposals, 4, device=device)
num_gt = gt_boxes.shape[0]
if num_gt < self.num_proposals:
# 3 * sigma = 1/2 --> sigma: 1/6
box_placeholder = torch.randn(
self.num_proposals - num_gt, 4, device=device) / 6. + 0.5
box_placeholder[:, 2:] = torch.clip(
box_placeholder[:, 2:], min=1e-4)
x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
else:
select_mask = [True] * self.num_proposals + \
[False] * (num_gt - self.num_proposals)
random.shuffle(select_mask)
x_start = gt_boxes[select_mask]
x_start = (x_start * 2. - 1.) * self.snr_scale
# noise sample
x = self.q_sample(x_start=x_start, time=time, noise=noise)
x = torch.clamp(x, min=-1 * self.snr_scale, max=self.snr_scale)
x = ((x / self.snr_scale) + 1) / 2.
diff_bboxes = bbox_cxcywh_to_xyxy(x)
# convert to abs bboxes
diff_bboxes_abs = diff_bboxes * image_size
metainfo = dict(time=time.squeeze(-1))
pred_instances = InstanceData(metainfo=metainfo)
pred_instances.diff_bboxes = diff_bboxes
pred_instances.diff_bboxes_abs = diff_bboxes_abs
pred_instances.noise = noise
return pred_instances
# forward diffusion
def q_sample(self, x_start, time, noise=None):
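        """Forward diffusion q(x_t | x_0).

        Implements the closed form
        ``x_t = sqrt(alphas_cumprod[t]) * x_0
        + sqrt(1 - alphas_cumprod[t]) * noise``.
        """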
if noise is None:
noise = torch.randn_like(x_start)
x_start_shape = x_start.shape
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, time,
x_start_shape)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, time, x_start_shape)
return sqrt_alphas_cumprod_t * x_start + \
sqrt_one_minus_alphas_cumprod_t * noise
def predict(self,
x: Tuple[Tensor],
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[obj:`InstanceData`]: Detection results of each image
after the post process.
"""
# hard-setting seed to keep results same (if necessary)
# seed = 0
# random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
device = x[-1].device
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
(time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
batch_image_size) = self.prepare_testing_targets(
batch_img_metas, device)
predictions = self.predict_by_feat(
x,
time_pairs=time_pairs,
batch_noise_bboxes=batch_noise_bboxes,
batch_noise_bboxes_raw=batch_noise_bboxes_raw,
batch_image_size=batch_image_size,
device=device,
batch_img_metas=batch_img_metas)
return predictions
def predict_by_feat(self,
x,
time_pairs,
batch_noise_bboxes,
batch_noise_bboxes_raw,
batch_image_size,
device,
batch_img_metas=None,
cfg=None,
rescale=True):
batch_size = len(batch_img_metas)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
ensemble_score, ensemble_label, ensemble_coord = [], [], []
for time, time_next in time_pairs:
batch_time = torch.full((batch_size, ),
time,
device=device,
dtype=torch.long)
# self_condition = x_start if self.self_condition else None
pred_logits, pred_bboxes = self(x, batch_noise_bboxes, batch_time)
x_start = pred_bboxes[-1]
x_start = x_start / batch_image_size[:, None, :]
x_start = bbox_xyxy_to_cxcywh(x_start)
x_start = (x_start * 2 - 1.) * self.snr_scale
x_start = torch.clamp(
x_start, min=-1 * self.snr_scale, max=self.snr_scale)
pred_noise = self.predict_noise_from_start(batch_noise_bboxes_raw,
batch_time, x_start)
pred_noise_list, x_start_list = [], []
noise_bboxes_list, num_remain_list = [], []
if self.box_renewal: # filter
score_thr = cfg.get('score_thr', 0)
for img_id in range(batch_size):
score_per_image = pred_logits[-1][img_id]
score_per_image = torch.sigmoid(score_per_image)
value, _ = torch.max(score_per_image, -1, keepdim=False)
keep_idx = value > score_thr
num_remain_list.append(torch.sum(keep_idx))
pred_noise_list.append(pred_noise[img_id, keep_idx, :])
x_start_list.append(x_start[img_id, keep_idx, :])
noise_bboxes_list.append(batch_noise_bboxes[img_id,
keep_idx, :])
if time_next < 0:
# Not same as original DiffusionDet
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image, scores_per_image, labels_per_image = \
self.inference(
box_cls=pred_logits[-1],
box_pred=pred_bboxes[-1],
cfg=cfg,
device=device)
ensemble_score.append(scores_per_image)
ensemble_label.append(labels_per_image)
ensemble_coord.append(box_pred_per_image)
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = self.ddim_sampling_eta * ((1 - alpha / alpha_next) *
(1 - alpha_next) /
(1 - alpha)).sqrt()
c = (1 - alpha_next - sigma**2).sqrt()
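            # Standard DDIM update, applied per image below:
            # x_{t_next} = sqrt(alpha_next) * x0_hat + c * eps_hat + sigma * z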
batch_noise_bboxes_list = []
batch_noise_bboxes_raw_list = []
for idx in range(batch_size):
pred_noise = pred_noise_list[idx]
x_start = x_start_list[idx]
noise_bboxes = noise_bboxes_list[idx]
num_remain = num_remain_list[idx]
noise = torch.randn_like(noise_bboxes)
noise_bboxes = x_start * alpha_next.sqrt() + \
c * pred_noise + sigma * noise
if self.box_renewal: # filter
# replenish with randn boxes
if num_remain < self.num_proposals:
noise_bboxes = torch.cat(
(noise_bboxes,
torch.randn(
self.num_proposals - num_remain,
4,
device=device)),
dim=0)
else:
select_mask = [True] * self.num_proposals + \
[False] * (num_remain -
self.num_proposals)
random.shuffle(select_mask)
noise_bboxes = noise_bboxes[select_mask]
# raw noise boxes
batch_noise_bboxes_raw_list.append(noise_bboxes)
# resize to xyxy
noise_bboxes = torch.clamp(
noise_bboxes,
min=-1 * self.snr_scale,
max=self.snr_scale)
noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
noise_bboxes = noise_bboxes * batch_image_size[idx]
batch_noise_bboxes_list.append(noise_bboxes)
batch_noise_bboxes = torch.stack(batch_noise_bboxes_list)
batch_noise_bboxes_raw = torch.stack(batch_noise_bboxes_raw_list)
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_per_image, scores_per_image, labels_per_image = \
self.inference(
box_cls=pred_logits[-1],
box_pred=pred_bboxes[-1],
cfg=cfg,
device=device)
ensemble_score.append(scores_per_image)
ensemble_label.append(labels_per_image)
ensemble_coord.append(box_pred_per_image)
if self.use_ensemble and self.sampling_timesteps > 1:
steps = len(ensemble_score)
results_list = []
for idx in range(batch_size):
ensemble_score_per_img = [
ensemble_score[i][idx] for i in range(steps)
]
ensemble_label_per_img = [
ensemble_label[i][idx] for i in range(steps)
]
ensemble_coord_per_img = [
ensemble_coord[i][idx] for i in range(steps)
]
scores_per_image = torch.cat(ensemble_score_per_img, dim=0)
labels_per_image = torch.cat(ensemble_label_per_img, dim=0)
box_pred_per_image = torch.cat(ensemble_coord_per_img, dim=0)
if self.use_nms:
det_bboxes, keep_idxs = batched_nms(
box_pred_per_image, scores_per_image, labels_per_image,
cfg.nms)
box_pred_per_image = box_pred_per_image[keep_idxs]
labels_per_image = labels_per_image[keep_idxs]
scores_per_image = det_bboxes[:, -1]
results = InstanceData()
results.bboxes = box_pred_per_image
results.scores = scores_per_image
results.labels = labels_per_image
results_list.append(results)
else:
box_cls = pred_logits[-1]
box_pred = pred_bboxes[-1]
results_list = self.inference(box_cls, box_pred, cfg, device)
if rescale:
results_list = self.do_results_post_process(
results_list, cfg, batch_img_metas=batch_img_metas)
return results_list
@staticmethod
def do_results_post_process(results_list, cfg, batch_img_metas=None):
processed_results = []
for results, img_meta in zip(results_list, batch_img_metas):
assert img_meta.get('scale_factor') is not None
scale_factor = [1 / s for s in img_meta['scale_factor']]
results.bboxes = scale_boxes(results.bboxes, scale_factor)
            # clip boxes to the original image shape
h, w = img_meta['ori_shape']
results.bboxes[:, 0::2] = results.bboxes[:, 0::2].clamp(
min=0, max=w)
results.bboxes[:, 1::2] = results.bboxes[:, 1::2].clamp(
min=0, max=h)
# filter small size bboxes
if cfg.get('min_bbox_size', 0) >= 0:
w, h = get_box_wh(results.bboxes)
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
results = results[valid_mask]
processed_results.append(results)
return processed_results
def prepare_testing_targets(self, batch_img_metas, device):
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == timesteps
times = torch.linspace(
-1, self.timesteps - 1, steps=self.sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
# [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
time_pairs = list(zip(times[:-1], times[1:]))
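        # With the defaults (timesteps=1000, sampling_timesteps=1) this is a
        # single DDIM step: times = [999, -1] -> time_pairs = [(999, -1)].
        # With sampling_timesteps=4 it would be
        # [(999, 749), (749, 499), (499, 249), (249, -1)].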
noise_bboxes_list = []
noise_bboxes_raw_list = []
image_size_list = []
for img_meta in batch_img_metas:
h, w = img_meta['img_shape']
image_size = torch.tensor([w, h, w, h],
dtype=torch.float32,
device=device)
noise_bboxes_raw = torch.randn((self.num_proposals, 4),
device=device)
noise_bboxes = torch.clamp(
noise_bboxes_raw, min=-1 * self.snr_scale, max=self.snr_scale)
noise_bboxes = ((noise_bboxes / self.snr_scale) + 1) / 2
noise_bboxes = bbox_cxcywh_to_xyxy(noise_bboxes)
noise_bboxes = noise_bboxes * image_size
noise_bboxes_raw_list.append(noise_bboxes_raw)
noise_bboxes_list.append(noise_bboxes)
image_size_list.append(image_size[None])
batch_noise_bboxes = torch.stack(noise_bboxes_list)
batch_image_size = torch.cat(image_size_list)
batch_noise_bboxes_raw = torch.stack(noise_bboxes_raw_list)
return (time_pairs, batch_noise_bboxes, batch_noise_bboxes_raw,
batch_image_size)
def predict_noise_from_start(self, x_t, t, x0):
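        """Recover the noise estimate from a predicted x_0.

        Inverts the forward process of `q_sample`:
        ``noise = (sqrt(1 / alphas_cumprod[t]) * x_t - x_0)
        / sqrt(1 / alphas_cumprod[t] - 1)``.
        """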
results = (extract(
self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
return results
def inference(self, box_cls, box_pred, cfg, device):
"""
Args:
box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
The tensor predicts the classification probability for
each proposal.
            box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
                The tensor predicts 4-vector (x1, y1, x2, y2) box
                regression values for every proposal.
        Returns:
            results (list[obj:`InstanceData`]): A list with one element per
                image.
"""
results = []
if self.use_focal_loss or self.use_fed_loss:
scores = torch.sigmoid(box_cls)
labels = torch.arange(
self.num_classes,
device=device).unsqueeze(0).repeat(self.num_proposals,
1).flatten(0, 1)
box_pred_list = []
scores_list = []
labels_list = []
for i, (scores_per_image,
box_pred_per_image) in enumerate(zip(scores, box_pred)):
scores_per_image, topk_indices = scores_per_image.flatten(
0, 1).topk(
self.num_proposals, sorted=False)
labels_per_image = labels[topk_indices]
box_pred_per_image = box_pred_per_image.view(-1, 1, 4).repeat(
1, self.num_classes, 1).view(-1, 4)
box_pred_per_image = box_pred_per_image[topk_indices]
if self.use_ensemble and self.sampling_timesteps > 1:
box_pred_list.append(box_pred_per_image)
scores_list.append(scores_per_image)
labels_list.append(labels_per_image)
continue
if self.use_nms:
det_bboxes, keep_idxs = batched_nms(
box_pred_per_image, scores_per_image, labels_per_image,
cfg.nms)
box_pred_per_image = box_pred_per_image[keep_idxs]
labels_per_image = labels_per_image[keep_idxs]
# some nms would reweight the score, such as softnms
scores_per_image = det_bboxes[:, -1]
result = InstanceData()
result.bboxes = box_pred_per_image
result.scores = scores_per_image
result.labels = labels_per_image
results.append(result)
else:
# For each box we assign the best class or the second
            # best if the best one is `no_object`.
scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1)
for i, (scores_per_image, labels_per_image,
box_pred_per_image) in enumerate(
zip(scores, labels, box_pred)):
if self.use_ensemble and self.sampling_timesteps > 1:
return box_pred_per_image, scores_per_image, \
labels_per_image
if self.use_nms:
det_bboxes, keep_idxs = batched_nms(
box_pred_per_image, scores_per_image, labels_per_image,
cfg.nms)
box_pred_per_image = box_pred_per_image[keep_idxs]
labels_per_image = labels_per_image[keep_idxs]
# some nms would reweight the score, such as softnms
scores_per_image = det_bboxes[:, -1]
result = InstanceData()
result.bboxes = box_pred_per_image
result.scores = scores_per_image
result.labels = labels_per_image
results.append(result)
if self.use_ensemble and self.sampling_timesteps > 1:
return box_pred_list, scores_list, labels_list
else:
return results
@MODELS.register_module()
class SingleDiffusionDetHead(nn.Module):
def __init__(
self,
num_classes=80,
feat_channels=256,
dim_feedforward=2048,
num_cls_convs=1,
num_reg_convs=3,
num_heads=8,
dropout=0.0,
pooler_resolution=7,
scale_clamp=_DEFAULT_SCALE_CLAMP,
bbox_weights=(2.0, 2.0, 1.0, 1.0),
use_focal_loss=True,
use_fed_loss=False,
act_cfg=dict(type='ReLU', inplace=True),
dynamic_conv=dict(dynamic_dim=64, dynamic_num=2)
) -> None:
super().__init__()
self.feat_channels = feat_channels
# Dynamic
self.self_attn = nn.MultiheadAttention(
feat_channels, num_heads, dropout=dropout)
self.inst_interact = DynamicConv(
feat_channels=feat_channels,
pooler_resolution=pooler_resolution,
dynamic_dim=dynamic_conv['dynamic_dim'],
dynamic_num=dynamic_conv['dynamic_num'])
self.linear1 = nn.Linear(feat_channels, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, feat_channels)
self.norm1 = nn.LayerNorm(feat_channels)
self.norm2 = nn.LayerNorm(feat_channels)
self.norm3 = nn.LayerNorm(feat_channels)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = build_activation_layer(act_cfg)
# block time mlp
self.block_time_mlp = nn.Sequential(
nn.SiLU(), nn.Linear(feat_channels * 4, feat_channels * 2))
# cls.
cls_module = list()
for _ in range(num_cls_convs):
cls_module.append(nn.Linear(feat_channels, feat_channels, False))
cls_module.append(nn.LayerNorm(feat_channels))
cls_module.append(nn.ReLU(inplace=True))
self.cls_module = nn.ModuleList(cls_module)
# reg.
reg_module = list()
for _ in range(num_reg_convs):
reg_module.append(nn.Linear(feat_channels, feat_channels, False))
reg_module.append(nn.LayerNorm(feat_channels))
reg_module.append(nn.ReLU(inplace=True))
self.reg_module = nn.ModuleList(reg_module)
# pred.
self.use_focal_loss = use_focal_loss
self.use_fed_loss = use_fed_loss
if self.use_focal_loss or self.use_fed_loss:
self.class_logits = nn.Linear(feat_channels, num_classes)
else:
self.class_logits = nn.Linear(feat_channels, num_classes + 1)
self.bboxes_delta = nn.Linear(feat_channels, 4)
self.scale_clamp = scale_clamp
self.bbox_weights = bbox_weights
def forward(self, features, bboxes, pro_features, pooler, time_emb):
"""
:param bboxes: (N, num_boxes, 4)
:param pro_features: (N, num_boxes, feat_channels)
"""
N, num_boxes = bboxes.shape[:2]
# roi_feature.
proposal_boxes = list()
for b in range(N):
proposal_boxes.append(bboxes[b])
rois = bbox2roi(proposal_boxes)
roi_features = pooler(features, rois)
if pro_features is None:
pro_features = roi_features.view(N, num_boxes, self.feat_channels,
-1).mean(-1)
roi_features = roi_features.view(N * num_boxes, self.feat_channels,
-1).permute(2, 0, 1)
# self_att.
pro_features = pro_features.view(N, num_boxes,
self.feat_channels).permute(1, 0, 2)
pro_features2 = self.self_attn(
pro_features, pro_features, value=pro_features)[0]
pro_features = pro_features + self.dropout1(pro_features2)
pro_features = self.norm1(pro_features)
# inst_interact.
pro_features = pro_features.view(
num_boxes, N,
self.feat_channels).permute(1, 0,
2).reshape(1, N * num_boxes,
self.feat_channels)
pro_features2 = self.inst_interact(pro_features, roi_features)
pro_features = pro_features + self.dropout2(pro_features2)
obj_features = self.norm2(pro_features)
# obj_feature.
obj_features2 = self.linear2(
self.dropout(self.activation(self.linear1(obj_features))))
obj_features = obj_features + self.dropout3(obj_features2)
obj_features = self.norm3(obj_features)
fc_feature = obj_features.transpose(0, 1).reshape(N * num_boxes, -1)
scale_shift = self.block_time_mlp(time_emb)
scale_shift = torch.repeat_interleave(scale_shift, num_boxes, dim=0)
scale, shift = scale_shift.chunk(2, dim=1)
fc_feature = fc_feature * (scale + 1) + shift
cls_feature = fc_feature.clone()
reg_feature = fc_feature.clone()
for cls_layer in self.cls_module:
cls_feature = cls_layer(cls_feature)
for reg_layer in self.reg_module:
reg_feature = reg_layer(reg_feature)
class_logits = self.class_logits(cls_feature)
bboxes_deltas = self.bboxes_delta(reg_feature)
pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))
return (class_logits.view(N, num_boxes,
-1), pred_bboxes.view(N, num_boxes,
-1), obj_features)
def apply_deltas(self, deltas, boxes):
"""Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
Args:
deltas (Tensor): transformation deltas of shape (N, k*4),
where k >= 1. deltas[i] represents k potentially
different class-specific box transformations for
the single box boxes[i].
boxes (Tensor): boxes to transform, of shape (N, 4)
"""
boxes = boxes.to(deltas.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = self.bbox_weights
dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
dh = deltas[:, 3::4] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=self.scale_clamp)
dh = torch.clamp(dh, max=self.scale_clamp)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes = torch.zeros_like(deltas)
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
return pred_boxes
class DynamicConv(nn.Module):
def __init__(self,
feat_channels: int,
dynamic_dim: int = 64,
dynamic_num: int = 2,
pooler_resolution: int = 7) -> None:
super().__init__()
self.feat_channels = feat_channels
self.dynamic_dim = dynamic_dim
self.dynamic_num = dynamic_num
self.num_params = self.feat_channels * self.dynamic_dim
self.dynamic_layer = nn.Linear(self.feat_channels,
self.dynamic_num * self.num_params)
self.norm1 = nn.LayerNorm(self.dynamic_dim)
self.norm2 = nn.LayerNorm(self.feat_channels)
self.activation = nn.ReLU(inplace=True)
num_output = self.feat_channels * pooler_resolution**2
self.out_layer = nn.Linear(num_output, self.feat_channels)
self.norm3 = nn.LayerNorm(self.feat_channels)
def forward(self, pro_features: Tensor, roi_features: Tensor) -> Tensor:
"""Forward function.
Args:
pro_features: (1, N * num_boxes, self.feat_channels)
roi_features: (49, N * num_boxes, self.feat_channels)
        Returns:
            Tensor: Object features of shape (N * num_boxes, feat_channels).
        """
features = roi_features.permute(1, 0, 2)
parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)
param1 = parameters[:, :, :self.num_params].view(
-1, self.feat_channels, self.dynamic_dim)
param2 = parameters[:, :,
self.num_params:].view(-1, self.dynamic_dim,
self.feat_channels)
features = torch.bmm(features, param1)
features = self.norm1(features)
features = self.activation(features)
features = torch.bmm(features, param2)
features = self.norm2(features)
features = self.activation(features)
features = features.flatten(1)
features = self.out_layer(features)
features = self.norm3(features)
features = self.activation(features)
return features
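# Usage sketch (illustrative only, kept as a comment so nothing runs on
# import): the head is normally built from a DiffusionDet config through the
# registry; the field values below simply mirror the defaults declared in
# `DynamicDiffusionDetHead.__init__`, and `test_cfg` is a hypothetical
# example of the keys read in `predict_by_feat`.
#
#   from mmengine.config import ConfigDict
#   from mmdet.registry import MODELS
#
#   bbox_head = MODELS.build(
#       dict(
#           type='DynamicDiffusionDetHead',
#           num_classes=80,
#           feat_channels=256,
#           num_proposals=500,
#           num_heads=6,
#           sampling_timesteps=1,
#           test_cfg=ConfigDict(
#               use_nms=True,
#               score_thr=0.5,
#               min_bbox_size=0,
#               nms=dict(type='nms', iou_threshold=0.5))))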