|
|
|
|
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmengine.model.weight_init import caffe2_xavier_init, kaiming_init |
|
from torch.nn import init |
|
|
|
from mmdet.registry import MODELS |
|
|
|
|
|
def _make_stack_3x3_convs(num_convs, |
|
in_channels, |
|
out_channels, |
|
act_cfg=dict(type='ReLU', inplace=True)): |
|
convs = [] |
|
for _ in range(num_convs): |
|
convs.append(nn.Conv2d(in_channels, out_channels, 3, padding=1)) |
|
convs.append(MODELS.build(act_cfg)) |
|
in_channels = out_channels |
|
return nn.Sequential(*convs) |
|
|
|
|
|
class InstanceBranch(nn.Module):
    """Predict per-instance class logits, mask kernels and objectness.

    A conv stack refines the input features, a 3x3 convolution produces
    ``num_masks`` activation maps (``iam``), and each map is used as a
    spatial weighting to pool one feature vector per instance. Three
    linear heads are then applied to the pooled vectors.

    Args:
        in_channels (int): Channels of the input feature map.
        dim (int): Channels produced by the conv stack and consumed by
            the heads.
        num_convs (int): Number of 3x3 conv + activation pairs.
        num_masks (int): Number of activation maps, i.e. instances
            predicted per image.
        num_classes (int): Output size of the classification head.
        kernel_dim (int): Output size of the mask-kernel head.
        act_cfg (dict): Activation config for the conv stack.
    """

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_classes = num_classes

        self.inst_convs = _make_stack_3x3_convs(num_convs, in_channels, dim,
                                                act_cfg)

        # One output channel per predicted instance.
        self.iam_conv = nn.Conv2d(dim, num_masks, 3, padding=1)

        self.cls_score = nn.Linear(dim, self.num_classes)
        self.mask_kernel = nn.Linear(dim, kernel_dim)
        self.objectness = nn.Linear(dim, 1)

        # Prior probability used to derive the initial bias of the
        # sigmoid-activated outputs in ``_init_weights``.
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        """Initialise the conv stack, ``iam_conv`` and the linear heads.

        ``objectness`` is not touched here and keeps the default
        ``nn.Linear`` initialisation.
        """
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        # Bias b such that sigmoid(b) == prior_prob.
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)

        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)

    def forward(self, features):
        """Run the branch on a batched feature map.

        Args:
            features (Tensor): Feature map of shape (B, in_channels, H, W).

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]: ``pred_logits`` of shape
            (B, N, num_classes), ``pred_kernel`` of shape
            (B, N, kernel_dim), ``pred_scores`` of shape (B, N, 1) and the
            raw (pre-sigmoid) activation maps ``iam`` of shape
            (B, N, H, W), where N is ``num_masks``.
        """
        features = self.inst_convs(features)

        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()

        B, N = iam_prob.shape[:2]
        C = features.size(1)

        # Normalise each map so its weights sum to 1 over all positions;
        # the clamp avoids dividing by zero for an all-zero map.
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]

        # (B, N, H*W) @ (B, H*W, C) -> (B, N, C): weighted feature pooling.
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))

        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
|
|
|
|
|
class MaskBranch(nn.Module):
    """Produce the mask feature map that instance kernels are applied to.

    Args:
        in_channels (int): Channels of the input feature map.
        dim (int): Channels produced by the conv stack.
        num_convs (int): Number of 3x3 conv + activation pairs.
        kernel_dim (int): Channels of the projected output.
        act_cfg (dict): Activation config for the conv stack.
    """

    def __init__(self,
                 in_channels,
                 dim=256,
                 num_convs=4,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.mask_convs = _make_stack_3x3_convs(
            num_convs, in_channels, dim, act_cfg)
        # 1x1 convolution mapping ``dim`` channels to ``kernel_dim``.
        self.projection = nn.Conv2d(dim, kernel_dim, 1)
        self._init_weights()

    def _init_weights(self):
        """Apply ``kaiming_init`` to the stack convs, then the projection."""
        conv_layers = [
            layer for layer in self.mask_convs.modules()
            if isinstance(layer, nn.Conv2d)
        ]
        conv_layers.append(self.projection)
        for conv in conv_layers:
            kaiming_init(conv)

    def forward(self, features):
        """Return the projected features of shape (B, kernel_dim, H, W).

        Args:
            features (Tensor): Feature map of shape (B, in_channels, H, W).
        """
        refined = self.mask_convs(features)
        projected = self.projection(refined)
        return projected
|
|
|
|
|
@MODELS.register_module()
class BaseIAMDecoder(nn.Module):
    """Decoder combining an instance branch with a mask branch.

    ``forward`` prepends two normalised coordinate channels to the input
    features, runs both branches on the result and multiplies the
    predicted per-instance kernels with the mask features to obtain one
    mask per instance.

    Args:
        in_channels (int): Input channels of both branches. Because
            ``forward`` concatenates two coordinate channels before the
            branches run, this must equal the channel count of the
            incoming features plus 2.
        num_classes (int): Number of classes for the classification head.
        ins_dim (int): Hidden channels of the instance branch.
        ins_conv (int): Number of convs in the instance branch.
        mask_dim (int): Hidden channels of the mask branch.
        mask_conv (int): Number of convs in the mask branch.
        kernel_dim (int): Size of each mask kernel, and the channel count
            of the mask features.
        scale_factor (float): Upsampling factor applied to the predicted
            masks (and to the activation maps when they are returned).
        output_iam (bool): Whether to add the upsampled activation maps to
            the output under ``'pred_iam'``.
        num_masks (int): Number of activation maps of the instance branch.
        act_cfg (dict): Activation config shared by both branches.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()

        self.scale_factor = scale_factor
        self.output_iam = output_iam

        self.inst_branch = InstanceBranch(
            in_channels,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.mask_branch = MaskBranch(
            in_channels,
            dim=mask_dim,
            num_convs=mask_conv,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)

    @staticmethod
    def _stack_locations(y_loc, x_loc, x):
        """Broadcast 1-D row/column coordinates into a (B, 2, H, W) map.

        Channel 0 holds the x coordinate and channel 1 the y coordinate.
        Plain view/expand is used instead of ``torch.meshgrid`` so that no
        ``indexing`` argument is needed on any torch version.
        """
        batch, h, w = x.shape[0], y_loc.numel(), x_loc.numel()
        y_map = y_loc.view(1, 1, h, 1).expand(batch, 1, h, w)
        x_map = x_loc.view(1, 1, 1, w).expand(batch, 1, h, w)
        locations = torch.cat([x_map, y_map], 1)
        # Match the dtype and device of the features they are joined to.
        return locations.to(x)

    @torch.no_grad()
    def compute_coordinates_linspace(self, x):
        """Return coordinates in [-1, 1] built with ``torch.linspace``.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Coordinates of shape (B, 2, H, W), with the dtype and
            device of ``x``.
        """
        h, w = x.size(2), x.size(3)
        y_loc = torch.linspace(-1, 1, h, device=x.device)
        x_loc = torch.linspace(-1, 1, w, device=x.device)
        return self._stack_locations(y_loc, x_loc, x)

    @torch.no_grad()
    def compute_coordinates(self, x):
        """Return coordinates in [-1, 1] built with ``torch.arange``.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Coordinates of shape (B, 2, H, W), with the dtype and
            device of ``x``.
        """
        h, w = x.size(2), x.size(3)
        # A dimension of size 1 would divide 0 by 0 and yield NaN; clamp
        # the denominator so the single coordinate becomes -1, matching
        # ``torch.linspace(-1, 1, 1)``.
        y_loc = -1.0 + 2.0 * torch.arange(h, device=x.device) / max(h - 1, 1)
        x_loc = -1.0 + 2.0 * torch.arange(w, device=x.device) / max(w - 1, 1)
        return self._stack_locations(y_loc, x_loc, x)

    def forward(self, features):
        """Predict class logits, objectness scores and instance masks.

        Args:
            features (Tensor): Feature map of shape
                (B, in_channels - 2, H, W).

        Returns:
            dict[str, Tensor]: ``'pred_logits'`` and ``'pred_scores'`` as
            produced by the instance branch, ``'pred_masks'`` holding one
            mask per instance upsampled by ``scale_factor``, and, when
            ``output_iam`` is set, ``'pred_iam'`` with the upsampled
            activation maps.
        """
        coord_features = self.compute_coordinates(features)
        features = torch.cat([coord_features, features], dim=1)
        pred_logits, pred_kernel, pred_scores, iam = self.inst_branch(features)
        mask_features = self.mask_branch(features)

        N = pred_kernel.shape[1]

        B, C, H, W = mask_features.shape
        # (B, N, C) @ (B, C, H*W) -> (B, N, H*W): each kernel acts as a
        # per-instance linear filter over the mask features.
        pred_masks = torch.bmm(pred_kernel,
                               mask_features.view(B, C,
                                                  H * W)).view(B, N, H, W)

        pred_masks = F.interpolate(
            pred_masks,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)

        output = {
            'pred_logits': pred_logits,
            'pred_masks': pred_masks,
            'pred_scores': pred_scores,
        }

        if self.output_iam:
            iam = F.interpolate(
                iam,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=False)
            output['pred_iam'] = iam

        return output
|
|
|
|
|
class GroupInstanceBranch(nn.Module):
    """Instance branch that pools features with grouped activation maps.

    ``iam_conv`` is a grouped convolution producing
    ``num_masks * num_groups`` activation maps. The feature vectors
    pooled by the maps of the different groups are concatenated per
    instance, giving ``num_masks`` vectors of size ``dim * num_groups``
    that pass through a shared linear layer before the prediction heads.

    Args:
        in_channels (int): Channels of the input feature map.
        num_groups (int): Number of groups of ``iam_conv``.
        dim (int): Channels produced by the conv stack.
        num_convs (int): Number of 3x3 conv + activation pairs.
        num_masks (int): Number of instances predicted per image.
        num_classes (int): Output size of the classification head.
        kernel_dim (int): Output size of the mask-kernel head.
        act_cfg (dict): Activation config for the conv stack.
    """

    def __init__(self,
                 in_channels,
                 num_groups=4,
                 dim=256,
                 num_convs=4,
                 num_masks=100,
                 num_classes=80,
                 kernel_dim=128,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__()
        self.num_groups = num_groups
        self.num_classes = num_classes

        self.inst_convs = _make_stack_3x3_convs(
            num_convs, in_channels, dim, act_cfg=act_cfg)

        # Size of the per-instance vector after concatenating all groups.
        expand_dim = dim * self.num_groups
        self.iam_conv = nn.Conv2d(
            dim,
            num_masks * self.num_groups,
            3,
            padding=1,
            groups=self.num_groups)

        self.fc = nn.Linear(expand_dim, expand_dim)

        self.cls_score = nn.Linear(expand_dim, self.num_classes)
        self.mask_kernel = nn.Linear(expand_dim, kernel_dim)
        self.objectness = nn.Linear(expand_dim, 1)

        # Prior probability used to derive the initial bias of the
        # sigmoid-activated outputs in ``_init_weights``.
        self.prior_prob = 0.01
        self._init_weights()

    def _init_weights(self):
        """Initialise the conv stack, ``iam_conv``, ``fc`` and the heads.

        ``objectness`` is not touched here and keeps the default
        ``nn.Linear`` initialisation.
        """
        for m in self.inst_convs.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
        # Bias b such that sigmoid(b) == prior_prob.
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for module in [self.iam_conv, self.cls_score]:
            init.constant_(module.bias, bias_value)
        init.normal_(self.iam_conv.weight, std=0.01)
        init.normal_(self.cls_score.weight, std=0.01)

        init.normal_(self.mask_kernel.weight, std=0.01)
        init.constant_(self.mask_kernel.bias, 0.0)
        caffe2_xavier_init(self.fc)

    def forward(self, features):
        """Run the branch on a batched feature map.

        Args:
            features (Tensor): Feature map of shape (B, in_channels, H, W).

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]: ``pred_logits`` of shape
            (B, M, num_classes), ``pred_kernel`` of shape
            (B, M, kernel_dim), ``pred_scores`` of shape (B, M, 1) and the
            raw (pre-sigmoid) activation maps ``iam`` of shape
            (B, M * num_groups, H, W), where M is ``num_masks``.
        """
        features = self.inst_convs(features)

        iam = self.iam_conv(features)
        iam_prob = iam.sigmoid()

        B, N = iam_prob.shape[:2]
        C = features.size(1)
        num_insts = N // self.num_groups

        # Normalise each map so its weights sum to 1 over all positions;
        # the clamp avoids dividing by zero for an all-zero map.
        iam_prob = iam_prob.view(B, N, -1)
        normalizer = iam_prob.sum(-1).clamp(min=1e-6)
        iam_prob = iam_prob / normalizer[:, :, None]

        # (B, N, H*W) @ (B, H*W, C) -> (B, N, C): weighted feature pooling.
        inst_features = torch.bmm(iam_prob,
                                  features.view(B, C, -1).permute(0, 2, 1))

        # Maps are laid out group-major along N. Split them by the actual
        # number of groups (not a fixed 4) and concatenate the vectors of
        # the same instance across groups: (B, N, C) -> (B, N/G, G*C).
        inst_features = inst_features.reshape(
            B, self.num_groups, num_insts,
            -1).transpose(1, 2).reshape(B, num_insts, -1)

        inst_features = F.relu_(self.fc(inst_features))

        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
|
|
|
|
|
@MODELS.register_module()
class GroupIAMDecoder(BaseIAMDecoder):
    """IAM decoder whose instance branch is a ``GroupInstanceBranch``.

    All arguments except ``num_groups`` are forwarded unchanged to
    ``BaseIAMDecoder``; ``num_groups`` is passed to the grouped instance
    branch only.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels,
            num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        # The base constructor has already created an instance branch;
        # replace it with the grouped variant.
        grouped_branch = GroupInstanceBranch(
            in_channels,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.inst_branch = grouped_branch
|
|
|
|
|
class GroupInstanceSoftBranch(GroupInstanceBranch):
    """Grouped instance branch that normalises activation maps by softmax.

    Construction is identical to ``GroupInstanceBranch`` plus one learnable
    scalar, ``softmax_bias``, that is added to the activation logits before
    the softmax.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Single-element parameter initialised to 1.
        self.softmax_bias = nn.Parameter(torch.ones(1))

    def forward(self, features):
        """Run the branch on a batched feature map.

        Args:
            features (Tensor): Feature map of shape (B, in_channels, H, W).

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]: ``pred_logits``,
            ``pred_kernel`` and ``pred_scores`` from the three linear
            heads, and the raw activation maps ``iam`` as produced by
            ``iam_conv``.
        """
        features = self.inst_convs(features)
        iam = self.iam_conv(features)

        batch, num_maps = iam.shape[:2]
        channels = features.size(1)
        per_group = num_maps // self.num_groups

        # Softmax over all spatial positions of each map. NOTE: softmax is
        # invariant to adding the same constant to every logit, so the
        # scalar bias does not alter the resulting weights.
        logits = iam.view(batch, num_maps, -1) + self.softmax_bias
        iam_prob = F.softmax(logits, dim=-1)

        # (B, N, H*W) @ (B, H*W, C) -> (B, N, C): weighted feature pooling.
        flat_features = features.view(batch, channels, -1).permute(0, 2, 1)
        inst_features = torch.bmm(iam_prob, flat_features)

        # Concatenate the vectors of the same instance across groups:
        # (B, N, C) -> (B, N/G, G*C).
        inst_features = inst_features.reshape(batch, self.num_groups,
                                              per_group, -1)
        inst_features = inst_features.transpose(1, 2).reshape(
            batch, per_group, -1)

        inst_features = F.relu_(self.fc(inst_features))

        pred_logits = self.cls_score(inst_features)
        pred_kernel = self.mask_kernel(inst_features)
        pred_scores = self.objectness(inst_features)
        return pred_logits, pred_kernel, pred_scores, iam
|
|
|
|
|
@MODELS.register_module()
class GroupIAMSoftDecoder(BaseIAMDecoder):
    """IAM decoder whose instance branch is a ``GroupInstanceSoftBranch``.

    All arguments except ``num_groups`` are forwarded unchanged to
    ``BaseIAMDecoder``; ``num_groups`` is passed to the softmax-based
    grouped instance branch only.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 num_groups=4,
                 ins_dim=256,
                 ins_conv=4,
                 mask_dim=256,
                 mask_conv=4,
                 kernel_dim=128,
                 scale_factor=2.0,
                 output_iam=False,
                 num_masks=100,
                 act_cfg=dict(type='ReLU', inplace=True)):
        super().__init__(
            in_channels,
            num_classes,
            ins_dim=ins_dim,
            ins_conv=ins_conv,
            mask_dim=mask_dim,
            mask_conv=mask_conv,
            kernel_dim=kernel_dim,
            scale_factor=scale_factor,
            output_iam=output_iam,
            num_masks=num_masks,
            act_cfg=act_cfg)
        # The base constructor has already created an instance branch;
        # replace it with the softmax-based grouped variant.
        soft_branch = GroupInstanceSoftBranch(
            in_channels,
            num_groups=num_groups,
            dim=ins_dim,
            num_convs=ins_conv,
            num_masks=num_masks,
            num_classes=num_classes,
            kernel_dim=kernel_dim,
            act_cfg=act_cfg)
        self.inst_branch = soft_branch
|
|