# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
                                         build_transformer_layer)
from mmengine.logging import print_log
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import SampleList


class KernelUpdator(nn.Module):
    """Dynamic Kernel Updator in Kernel Update Head.

    Args:
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        feat_channels (int): The number of middle-stage channels in
            the kernel updator. Default: 64.
        out_channels (int): The number of output channels.
        gate_sigmoid (bool): Whether to use the sigmoid function in the gate
            mechanism. Default: True.
        gate_norm_act (bool): Whether to add a normalization and activation
            layer in the gate mechanism. Default: False.
        activate_out (bool): Whether to add an activation after the gate
            mechanism. Default: False.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='LN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        gate_sigmoid=True,
        gate_norm_act=False,
        activate_out=False,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='ReLU', inplace=True),
    ):
        super().__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        # Note: ``input_layer`` and the two gate layers below are applied to
        # ``feat_channels``-dim features in ``forward``, so in practice the
        # module is used with ``feat_channels == in_channels``.
        self.input_layer = nn.Linear(self.in_channels,
                                     self.num_params_in + self.num_params_out,
                                     1)
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """Forward function of KernelUpdator.

        Args:
            update_feature (torch.Tensor): Feature map assembled from
                each group. It is reshaped so that its last dimension
                equals ``self.in_channels``.
            input_feature (torch.Tensor): Intermediate feature
                with shape: (N, num_classes, conv_kernel_size**2, channels).

        Returns:
            Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
            the number of classes, C1 and C2 are the feature map channels of
            KernelUpdateHead and KernelUpdator, respectively.
        """
        update_feature = update_feature.reshape(-1, self.in_channels)
        num_proposals = update_feature.size(0)
        # dynamic_layer works for
        # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[:, :self.num_params_in].view(
            -1, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels)

        # input_layer works for
        # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
        input_feats = self.input_layer(
            input_feature.reshape(num_proposals, -1, self.feat_channels))
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        # `gate_feats` is F^G in K-Net paper
        gate_feats = input_in * param_in.unsqueeze(-2)
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)
        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # Gate mechanism. Eq.(5) in original paper.
        # param_out has shape (num_proposals, feat_channels) and is
        # broadcast over the K*K dimension.
        features = update_gate * param_out.unsqueeze(
            -2) + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features
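

# A minimal shape-check sketch for ``KernelUpdator``; it is not part of the
# original module. The helper name ``_demo_kernel_updator`` and the tiny
# channel sizes are made up for illustration; as noted above, the module is
# exercised with ``feat_channels == in_channels``.
def _demo_kernel_updator():
    channels = 8
    updator = KernelUpdator(
        in_channels=channels, feat_channels=channels, out_channels=channels)
    batch_size, num_classes, kernel_area = 2, 3, 1
    # Group features assembled from the masks: (B, N, C).
    update_feat = torch.rand(batch_size, num_classes, channels)
    # Current kernels viewed as (B, N, K*K, C).
    input_feat = torch.rand(batch_size, num_classes, kernel_area, channels)
    out = updator(update_feat, input_feat)
    # The output is flattened over (B, N): (B*N, K*K, out_channels).
    assert out.shape == (batch_size * num_classes, kernel_area, channels)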


class KernelUpdateHead(nn.Module):
    """Kernel Update Head in K-Net.

    Args:
        num_classes (int): Number of classes. Default: 150.
        num_ffn_fcs (int): The number of fully-connected layers in
            FFNs. Default: 2.
        num_heads (int): The number of parallel attention heads.
            Default: 8.
        num_mask_fcs (int): The number of fully connected layers for
            mask prediction. Default: 3.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 2048.
        in_channels (int): The number of channels of input feature map.
            Default: 256.
        out_channels (int): The number of output channels.
            Default: 256.
        dropout (float): The probability of an element to be
            zeroed in MultiheadAttention and FFN. Default: 0.0.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        ffn_act_cfg (dict): Config of activation layers in FFN.
            Default: dict(type='ReLU').
        conv_kernel_size (int): The kernel size of convolution in
            Kernel Update Head for the dynamic kernel update.
            Default: 1.
        feat_transform_cfg (dict | None): Config of feature transform.
            Default: None.
        kernel_init (bool): Whether to initialize the mask kernel in the
            mask head. Default: False.
        with_ffn (bool): Whether to add an FFN in the kernel update head.
            Default: True.
        feat_gather_stride (int): Stride of convolution in feature transform.
            Default: 1.
        mask_transform_stride (int): Stride of mask transform.
            Default: 1.
        kernel_updator_cfg (dict): Config of kernel updator.
            Default: dict(
                type='DynamicConv',
                in_channels=256,
                feat_channels=64,
                out_channels=256,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN')).
    """

    def __init__(self,
                 num_classes=150,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_mask_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 out_channels=256,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 conv_kernel_size=1,
                 feat_transform_cfg=None,
                 kernel_init=False,
                 with_ffn=True,
                 feat_gather_stride=1,
                 mask_transform_stride=1,
                 kernel_updator_cfg=dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN'))):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False
        self.dropout = dropout
        self.num_heads = num_heads
        self.kernel_init = kernel_init
        self.with_ffn = with_ffn
        self.conv_kernel_size = conv_kernel_size
        self.feat_gather_stride = feat_gather_stride
        self.mask_transform_stride = mask_transform_stride

        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
                                            num_heads, dropout)
        self.attention_norm = build_norm_layer(
            dict(type='LN'), in_channels * conv_kernel_size**2)[1]
        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)

        if feat_transform_cfg is not None:
            kernel_size = feat_transform_cfg.pop('kernel_size', 1)
            transform_channels = in_channels
            self.feat_transform = ConvModule(
                transform_channels,
                in_channels,
                kernel_size,
                stride=feat_gather_stride,
                padding=int(feat_gather_stride // 2),
                **feat_transform_cfg)
        else:
            self.feat_transform = None

        if self.with_ffn:
            self.ffn = FFN(
                in_channels,
                feedforward_channels,
                num_ffn_fcs,
                act_cfg=ffn_act_cfg,
                dropout=dropout)
            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.mask_fcs = nn.ModuleList()
        for _ in range(num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.mask_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.mask_fcs.append(build_activation_layer(act_cfg))

        self.fc_mask = nn.Linear(in_channels, out_channels)

    def init_weights(self):
        """Use xavier initialization for all weight parameters and set the
        classification head bias to a specific value when using focal
        loss."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.kernel_init:
            print_log(
                'mask kernel in mask head is normal initialized by std 0.01')
            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)

    def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            x (Tensor): Feature map from FPN with shape
                (batch_size, feature_dimensions, H, W).
            proposal_feat (Tensor): Intermediate feature from the head in
                the previous stage, has shape
                (batch_size, num_proposals, feature_dimensions).
            mask_preds (Tensor): Mask prediction from the previous stage
                with shape (batch_size, num_proposals, H, W).
            mask_shape (tuple[int], optional): The target spatial size
                (H, W) used to resize the predicted masks. Default: None.

        Returns:
            Tuple: The first tensor is the predicted mask with shape
            (N, num_classes, H, W), the second tensor is the dynamic kernel
            with shape (N, num_classes, channels, K, K).
        """
        N, num_proposals = proposal_feat.shape[:2]
        if self.feat_transform is not None:
            x = self.feat_transform(x)

        C, H, W = x.shape[-3:]

        mask_h, mask_w = mask_preds.shape[-2:]
        if mask_h != H or mask_w != W:
            gather_mask = F.interpolate(
                mask_preds, (H, W), align_corners=False, mode='bilinear')
        else:
            gather_mask = mask_preds

        sigmoid_masks = gather_mask.softmax(dim=1)

        # Group Feature Assembling. Eq.(3) in original paper.
        # einsum is faster than bmm by 30%
        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)

        # proposal_feat in shape [B, N, C, K, K]
        # -> [B, N, C, K*K] -> [B, N, K*K, C]
        proposal_feat = proposal_feat.reshape(N, num_proposals,
                                              self.in_channels,
                                              -1).permute(0, 1, 3, 2)
        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)

        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
        obj_feat = self.attention_norm(self.attention(obj_feat))
        # [N, B, K*K*C] -> [B, N, K*K*C]
        obj_feat = obj_feat.permute(1, 0, 2)

        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)

        # FFN
        if self.with_ffn:
            obj_feat = self.ffn_norm(self.ffn(obj_feat))

        mask_feat = obj_feat

        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)

        # [B, N, K*K, C] -> [B, N, C, K*K]
        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)

        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
            mask_x = F.interpolate(
                x, scale_factor=0.5, mode='bilinear', align_corners=False)
            H, W = mask_x.shape[-2:]
        else:
            mask_x = x
        # group conv is 5x faster than unfold and uses about 1/5 memory
        # Group conv vs. unfold vs. concat batch, 2.9ms : 13.5ms : 3.8ms
        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
        # but in real training group conv is slower than concat batch
        # so we keep using concat batch.
        # fold_x = F.unfold(
        #     mask_x,
        #     self.conv_kernel_size,
        #     padding=int(self.conv_kernel_size // 2))
        # mask_feat = mask_feat.reshape(N, num_proposals, -1)
        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
        # [B, N, C, K*K] -> [B, N, C, K, K]
        mask_feat = mask_feat.reshape(N, num_proposals, C,
                                      self.conv_kernel_size,
                                      self.conv_kernel_size)
        # Convolve each image with its own dynamic kernels of shape
        # [num_proposals, C, K, K].
        new_mask_preds = []
        for i in range(N):
            new_mask_preds.append(
                F.conv2d(
                    mask_x[i:i + 1],
                    mask_feat[i],
                    padding=int(self.conv_kernel_size // 2)))

        new_mask_preds = torch.cat(new_mask_preds, dim=0)
        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
        if self.mask_transform_stride == 2:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                scale_factor=2,
                mode='bilinear',
                align_corners=False)

        if mask_shape is not None and mask_shape[0] != H:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                mask_shape,
                align_corners=False,
                mode='bilinear')

        return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
            N, num_proposals, self.in_channels, self.conv_kernel_size,
            self.conv_kernel_size)
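

# The sketch below replays, with plain tensors, two core steps of
# ``KernelUpdateHead.forward``: group feature assembling (Eq.(3) of the K-Net
# paper) and per-image mask prediction with dynamic kernels. It is an
# illustration only; ``_demo_group_assembling_and_mask_pred`` is not part of
# the original module and the tiny shapes are arbitrary.
def _demo_group_assembling_and_mask_pred():
    B, N, C, H, W, K = 2, 3, 4, 8, 8, 1
    x = torch.rand(B, C, H, W)           # backbone/FPN features
    mask_preds = torch.rand(B, N, H, W)  # masks from the previous stage

    # Group Feature Assembling: weight the feature map by the (softmaxed)
    # masks and sum over the spatial dims, giving one feature per proposal.
    assembled = torch.einsum('bnhw,bchw->bnc', mask_preds.softmax(dim=1), x)
    assert assembled.shape == (B, N, C)

    # Mask prediction: convolve each image with its own N dynamic kernels.
    kernels = torch.rand(B, N, C, K, K)
    new_masks = torch.cat([
        F.conv2d(x[i:i + 1], kernels[i], padding=K // 2) for i in range(B)
    ])
    assert new_masks.shape == (B, N, H, W)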


class IterativeDecodeHead(BaseDecodeHead):
    """K-Net: Towards Unified Image Segmentation.

    This head is the implementation of
    `K-Net: <https://arxiv.org/abs/2106.14855>`_.

    Args:
        num_stages (int): The number of stages (kernel update heads)
            in IterativeDecodeHead. Default: 3.
        kernel_generate_head (dict): Config of the kernel generate head, which
            generates mask predictions, dynamic kernels and class predictions
            for the following kernel update heads.
        kernel_update_head (list[dict]): Configs of the kernel update heads,
            which refine the dynamic kernels and class predictions
            iteratively.
    """

    def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
                 **kwargs):
        # ``IterativeDecodeHead`` skips the initialization of
        # ``BaseDecodeHead``, which is performed when building
        # ``self.kernel_generate_head``.
        super(BaseDecodeHead, self).__init__(**kwargs)
        assert num_stages == len(kernel_update_head)
        self.num_stages = num_stages
        self.kernel_generate_head = MODELS.build(kernel_generate_head)
        self.kernel_update_head = nn.ModuleList()
        self.align_corners = self.kernel_generate_head.align_corners
        self.num_classes = self.kernel_generate_head.num_classes
        self.input_transform = self.kernel_generate_head.input_transform
        self.ignore_index = self.kernel_generate_head.ignore_index
        self.out_channels = self.num_classes

        for head_cfg in kernel_update_head:
            self.kernel_update_head.append(MODELS.build(head_cfg))

    def forward(self, inputs):
        """Forward function."""
        feats = self.kernel_generate_head._forward_feature(inputs)
        sem_seg = self.kernel_generate_head.cls_seg(feats)
        seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
        seg_kernels = seg_kernels[None].expand(
            feats.size(0), *seg_kernels.size())

        stage_segs = [sem_seg]
        for i in range(self.num_stages):
            sem_seg, seg_kernels = self.kernel_update_head[i](feats,
                                                              seg_kernels,
                                                              sem_seg)
            stage_segs.append(sem_seg)
        if self.training:
            return stage_segs
        # only return the prediction of the last stage during testing
        return stage_segs[-1]

    def loss_by_feat(self, seg_logits: List[Tensor],
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute the segmentation loss of every stage and collect the
        losses with a per-stage ``.s{i}`` suffix."""
        losses = dict()
        for i, logit in enumerate(seg_logits):
            loss = self.kernel_generate_head.loss_by_feat(
                logit, batch_data_samples)
            for k, v in loss.items():
                losses[f'{k}.s{i}'] = v

        return losses
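

# Below is a minimal, illustrative sketch of how ``IterativeDecodeHead`` can
# be wired in a decode-head config: a kernel-generate head produces the
# initial segmentation logits and kernels, and ``num_stages`` kernel update
# heads refine them. The values are placeholders loosely modelled on
# K-Net-style setups rather than a shipped config, and the ``type`` strings
# assume the classes above are registered under these names.
_example_decode_head_cfg = dict(
    type='IterativeDecodeHead',
    num_stages=3,
    kernel_generate_head=dict(
        type='FCNHead',  # any decode head exposing ``conv_seg`` kernels
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=150),
    kernel_update_head=[
        dict(
            type='KernelUpdateHead',
            num_classes=150,
            in_channels=256,
            out_channels=256,
            conv_kernel_size=1,
            kernel_updator_cfg=dict(
                type='KernelUpdator',
                in_channels=256,
                feat_channels=256,
                out_channels=256,
                norm_cfg=dict(type='LN'),
                act_cfg=dict(type='ReLU', inplace=True))) for _ in range(3)
    ])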