Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init | |
| from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d | |
| from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version | |
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The same input is convolved twice: once with the configured
    padding/dilation (small rate) and once with both multiplied by 3
    (large rate, using ``weight + weight_diff``). The two results are
    blended per location by a learned ``switch`` map. Global-context
    terms are added to the input before and to the output after the
    switchable convolution.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform (bool, optional): If ``True``, replace convolution with
            deformable convolution. Default: ``False``.

    Note:
        ``padding_mode`` is not a constructor argument of this class; whatever
        default the parent class uses applies.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv producing the single-channel blending map; it shares the
        # main conv's stride so its output can be multiplied with the conv
        # results.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive weight delta used only by the large-rate branch. The
        # storage is uninitialised here and zeroed in ``init_weights``.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        # 1x1 convs applied to globally pooled features (see ``forward``).
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # NOTE(review): 18 output channels looks like 2 offsets per tap of
            # a 3x3 kernel, i.e. this presumably assumes kernel_size == 3 --
            # confirm before using other kernel sizes with use_deform=True.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the SAC-specific layers with constant values.

        ``weight_diff`` is zeroed, so both branches start with the same
        kernel. The remaining layers are initialised through
        ``constant_init`` (presumably weight=0 with the given/default bias --
        verify against ``constant_init``).
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _conv_branch(self, x, avg_x, weight, zero_bias, offset_conv=None):
        """Run one SAC branch with the current padding and dilation.

        Args:
            x (Tensor): Input of the convolution.
            avg_x (Tensor): Locally averaged input, fed to ``offset_conv``
                when deformable convolution is enabled.
            weight (Tensor): Kernel to convolve with.
            zero_bias (Tensor): All-zero bias passed to ``_conv_forward`` on
                torch >= 1.8.0, where the bias argument is required.
            offset_conv (nn.Module, optional): Offset predictor; only used
                when ``self.use_deform`` is ``True``.

        Returns:
            Tensor: Result of the (deformable) convolution.
        """
        if self.use_deform:
            offset = offset_conv(avg_x)
            # The trailing 1 is presumably ``deform_groups`` -- TODO confirm
            # against ``deform_conv2d``'s signature.
            return deform_conv2d(x, offset, weight, self.stride, self.padding,
                                 self.dilation, self.groups, 1)
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is a required argument of _conv_forward in torch 1.8.0
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply switchable atrous convolution to ``x``.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``switch * out_s + (1 - switch) * out_l`` plus the
            post-context term.
        """
        # pre-context: add a transformed global average back to every position
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x

        # switch: reflect-pad by 2 then 5x5 average (stride 1), which keeps
        # the spatial size, and predict the blending map from it
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)

        # sac
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)
        offset_s = self.offset_s if self.use_deform else None
        offset_l = self.offset_l if self.use_deform else None

        # small-rate branch: configured padding / dilation
        out_s = self._conv_branch(x, avg_x, weight, zero_bias, offset_s)

        # large-rate branch: the parent's conv helpers read self.padding and
        # self.dilation, so they are swapped temporarily. ``finally``
        # guarantees the module is restored even if the convolution raises.
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in ori_p)
        self.dilation = tuple(3 * d for d in ori_d)
        try:
            weight = weight + self.weight_diff
            out_l = self._conv_branch(x, avg_x, weight, zero_bias, offset_l)
        finally:
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l

        # post-context: same global-average trick on the output
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out