# kairunwen's picture
# Update Code
# 57746f1
r""" Implementation of center-pivot 4D convolution """
import torch
import torch.nn as nn
class CenterPivotConv4d(nn.Module):
    r""" CenterPivot 4D conv

    Center-pivot 4D convolution built from two 2D convolutions whose outputs
    are summed. The input is a 6D tensor of size (bsz, ch, ha, wa, hb, wb):

    * ``conv1`` is applied over the (ha, wa) dimensions for every (hb, wb)
      position (after optionally sub-sampling hb/wb, see :meth:`prune`).
    * ``conv2`` is applied over the (hb, wb) dimensions for every (ha, wa)
      position.

    Args:
        in_channels: number of input channels (shared by both 2D convs).
        out_channels: number of output channels (shared by both 2D convs).
        kernel_size: 4-element sequence; ``[:2]`` is used by ``conv1`` and
            ``[2:]`` by ``conv2``.
        stride: 4-element sequence, split the same way as ``kernel_size``.
        padding: 4-element sequence (tuple or list), split the same way.
        bias: whether both 2D convolutions carry a bias term.

    NOTE(review): the ``conv2`` branch keeps the input (ha, wa) size, so the
    final sum only lines up when ``conv1`` preserves (ha, wa) - e.g. stride 1
    with size-preserving padding. Confirm against callers before using other
    settings for the first two dimensions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True):
        super(CenterPivotConv4d, self).__init__()

        # conv1 acts on the first spatial pair (ha, wa).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2],
                               bias=bias, padding=padding[:2])
        # conv2 acts on the second spatial pair (hb, wb).
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:],
                               bias=bias, padding=padding[2:])

        self.stride34 = stride[2:]
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Cache for the flat sub-sampling index built by `prune`. It is kept
        # as a plain attribute (not a registered buffer) so that state_dict
        # keys stay unchanged.
        self.idx_initialized = False
        self._idx_key = None

    def prune(self, ct):
        r""" Sub-sample the last two dimensions of `ct` with the (hb, wb) strides.

        Equivalent to ``ct[..., ::stride[2], ::stride[3]]``, implemented with a
        cached flat index and ``index_select``.

        Args:
            ct: 6D tensor of size (bsz, ch, ha, wa, hb, wb).

        Returns:
            Tensor of size (bsz, ch, ha, wa, len_h, len_w), where len_h/len_w
            are the number of positions kept along hb/wb.
        """
        bsz, ch, ha, wa, hb, wb = ct.size()

        # The flat index depends on the correlation size and on the device, so
        # it must be rebuilt whenever either changes; reusing the index from
        # the first call would select wrong (or out-of-range) elements.
        cache_key = (hb, wb, ct.device)
        if not self.idx_initialized or getattr(self, '_idx_key', None) != cache_key:
            idxh = torch.arange(start=0, end=hb, step=self.stride[2], device=ct.device)
            idxw = torch.arange(start=0, end=wb, step=self.stride[3], device=ct.device)
            self.len_h = len(idxh)
            self.len_w = len(idxw)
            # Flat position of (h, w) inside a row-major (hb, wb) grid is h * wb + w.
            self.idx = (idxh.unsqueeze(1) * wb + idxw.unsqueeze(0)).reshape(-1)
            self._idx_key = cache_key
            self.idx_initialized = True

        # reshape (rather than view) so non-contiguous inputs are accepted too.
        flat = ct.reshape(bsz, ch, ha, wa, -1)
        ct_pruned = flat.index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w)

        return ct_pruned

    def forward(self, x):
        r""" Apply the two 2D convolutions and sum their outputs.

        Args:
            x: 6D tensor of size (bsz, in_channels, ha, wa, hb, wb).

        Returns:
            6D tensor (bsz, out_channels, o_ha, o_wa, o_hb, o_wb), or a 4D
            tensor (bsz, out_channels, o_ha, o_wa) when the (hb, wb) padding is
            zero and the two branches disagree on their last two dimensions.
        """
        # Branch 1: sub-sample (hb, wb) when either of their strides exceeds 1,
        # so that it matches the strided output of conv2 in branch 2.
        if any(s > 1 for s in self.stride[2:]):
            out1 = self.prune(x)
        else:
            out1 = x
        bsz, inch, ha, wa, hb, wb = out1.size()
        # Fold (bsz, hb, wb) into the batch axis and convolve over (ha, wa).
        out1 = out1.permute(0, 4, 5, 1, 2, 3).contiguous().view(-1, inch, ha, wa)
        out1 = self.conv1(out1)
        outch, o_ha, o_wa = out1.size(-3), out1.size(-2), out1.size(-1)
        out1 = out1.view(bsz, hb, wb, outch, o_ha, o_wa).permute(0, 3, 4, 5, 1, 2).contiguous()

        # Branch 2: fold (bsz, ha, wa) into the batch axis and convolve over (hb, wb).
        bsz, inch, ha, wa, hb, wb = x.size()
        out2 = x.permute(0, 2, 3, 1, 4, 5).contiguous().view(-1, inch, hb, wb)
        out2 = self.conv2(out2)
        outch, o_hb, o_wb = out2.size(-3), out2.size(-2), out2.size(-1)
        out2 = out2.view(bsz, ha, wa, outch, o_hb, o_wb).permute(0, 3, 1, 2, 4, 5).contiguous()

        # With zero padding on (hb, wb) the branches can disagree on the last
        # two dimensions. tuple(...) makes the check work for list padding too.
        unpadded = tuple(self.padding[-2:]) == (0, 0)
        if out1.size()[-2:] != out2.size()[-2:] and unpadded:
            # Collapse the (hb, wb) dimensions of branch 1 by summation.
            out1 = out1.view(bsz, outch, o_ha, o_wa, -1).sum(dim=-1)
            # Drop only the trailing singleton dimensions of branch 2; a bare
            # squeeze() would also drop batch/channel/spatial axes of size 1.
            out2 = out2.squeeze(5).squeeze(4)

        y = out1 + out2
        return y