import torch
from torch import nn

from detectron2.layers import Conv2d


def _as_pair(value):
    """Return ``value`` as a 2-tuple, duplicating it when it is a scalar."""
    if isinstance(value, (list, tuple)):
        return tuple(value)
    return (value, value)


class _NewEmptyTensorOp(torch.autograd.Function):
    """Autograd function that returns an uninitialized tensor of a given shape.

    ``DFConv2d.forward`` uses it when the input has no elements, so that an
    output of the requested shape is produced through ``Function.apply``
    instead of running the convolution.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        # Remember the input shape so that backward can emit a gradient
        # tensor of that same shape.
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        # One return value per forward input: a tensor shaped like ``x``
        # and ``None`` for the non-tensor ``new_shape`` argument.
        return _NewEmptyTensorOp.apply(grad, shape), None


class DFConv2d(nn.Module):
    """Deformable convolutional layer.

    Combines a regular ``Conv2d`` that predicts offsets (and, for the
    modulated variant, masks) with detectron2's ``DeformConv`` or
    ``ModulatedDeformConv`` that consumes them.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        with_modulated_dcn=True,
        kernel_size=3,
        stride=1,
        groups=1,
        dilation=1,
        deformable_groups=1,
        bias=False,
        padding=None
    ):
        """Build the offset predictor and the deformable convolution.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels of the deformable conv.
            with_modulated_dcn: Use ``ModulatedDeformConv`` (offsets + mask)
                when true, plain ``DeformConv`` (offsets only) otherwise.
            kernel_size: Int or 2-sequence. When it is a sequence, ``stride``
                and ``dilation`` must be 2-sequences as well.
            stride: Convolution stride (int or 2-sequence).
            groups: Groups of the deformable convolution.
            dilation: Convolution dilation (int or 2-sequence).
            deformable_groups: Number of deformable groups.
            bias: Whether the deformable convolution has a bias term.
            padding: Accepted for interface compatibility only; the value is
                ignored because padding is always derived from
                ``kernel_size`` and ``dilation``.
        """
        super().__init__()
        if isinstance(kernel_size, (list, tuple)):
            assert isinstance(stride, (list, tuple))
            assert isinstance(dilation, (list, tuple))
            assert len(kernel_size) == 2
            assert len(stride) == 2
            assert len(dilation) == 2
            padding = (
                dilation[0] * (kernel_size[0] - 1) // 2,
                dilation[1] * (kernel_size[1] - 1) // 2
            )
            offset_base_channels = kernel_size[0] * kernel_size[1]
        else:
            padding = dilation * (kernel_size - 1) // 2
            offset_base_channels = kernel_size * kernel_size

        # The deformable ops are imported lazily, only for the variant used.
        if with_modulated_dcn:
            from detectron2.layers.deform_conv import ModulatedDeformConv
            offset_channels = offset_base_channels * 3  # default: 27
            conv_block = ModulatedDeformConv
        else:
            from detectron2.layers.deform_conv import DeformConv
            offset_channels = offset_base_channels * 2  # default: 18
            conv_block = DeformConv

        # Offset (and mask) predictor; it shares the geometry of the
        # deformable conv and starts from all-zero weights and bias.
        self.offset = Conv2d(
            in_channels,
            deformable_groups * offset_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=1,
            dilation=dilation
        )
        nn.init.constant_(self.offset.weight, 0)
        nn.init.constant_(self.offset.bias, 0)

        self.conv = conv_block(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            deformable_groups=deformable_groups,
            bias=bias
        )
        self.with_modulated_dcn = with_modulated_dcn
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # Channel index at which the predictor output switches from offsets
        # to mask logits (only used by the modulated variant).
        self.offset_split = offset_base_channels * deformable_groups * 2

    def forward(self, x, return_offset=False):
        """Apply the deformable convolution to ``x``.

        Args:
            x: Input tensor; its last two dimensions are treated as the
                spatial dimensions and its first as the batch dimension.
            return_offset: When true, also return the raw output of the
                offset predictor.

        Returns:
            The convolved tensor, or ``(tensor, offset_mask)`` when
            ``return_offset`` is true. For an input without elements both
            tensors are uninitialized tensors of the computed output shape.
        """
        if x.numel() > 0:
            offset_mask = self.offset(x)
            if not self.with_modulated_dcn:
                x = self.conv(x, offset_mask)
            else:
                offset = offset_mask[:, :self.offset_split, :, :]
                mask = offset_mask[:, self.offset_split:, :, :].sigmoid()
                x = self.conv(x, offset, mask)
            if return_offset:
                return x, offset_mask
            return x

        # Empty input: compute the output shape without running the conv.
        # The stored hyper-parameters may be scalars (the default) or pairs,
        # so normalise each one to a pair before zipping them together.
        spatial_shape = [
            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
            for i, p, di, k, d in zip(
                x.shape[-2:],
                _as_pair(self.padding),
                _as_pair(self.dilation),
                _as_pair(self.kernel_size),
                _as_pair(self.stride)
            )
        ]
        output_shape = [x.shape[0], self.conv.weight.shape[0]] + spatial_shape
        output = _NewEmptyTensorOp.apply(x, output_shape)
        if return_offset:
            # Keep the return arity identical to the non-empty path. The
            # offset predictor uses the same kernel/stride/padding/dilation,
            # so its spatial output shape matches the one computed above.
            offset_shape = (
                [x.shape[0], self.offset.weight.shape[0]] + spatial_shape
            )
            return output, _NewEmptyTensorOp.apply(x, offset_shape)
        return output