# ghost/network/AADLayer.py
# Jagrut Thakare
# v1
# 9be8aa9
import torch
import torch.nn as nn
class AADLayer(nn.Module):
    """Blend attribute- and identity-conditioned modulations of a feature map.

    The input features are instance-normalised and then modulated twice:
    once with per-pixel scale/shift maps produced by 1x1 convolutions of
    ``z_attr``, and once with per-channel scale/shift vectors produced by
    linear layers applied to ``z_id``. A single-channel sigmoid mask computed
    from the normalised features mixes the two results.

    Args:
        c_x: Number of channels of the feature map ``h_in``.
        attr_c: Number of channels of the attribute map ``z_attr``.
        c_id: Size of the last dimension of the identity embedding ``z_id``.
    """

    def __init__(self, c_x, attr_c, c_id):
        super(AADLayer, self).__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # Attribute branch: 1x1 convs give a scale (conv1) and shift (conv2) map.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        # Identity branch: linear layers give a per-channel scale (fc1) and shift (fc2).
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        # Parameter-free normalisation; the scale/shift come from the branches above.
        self.norm = nn.InstanceNorm2d(c_x, affine=False)
        # Produces the one-channel mask logits used to mix the two branches.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        """Return the mask-weighted mix of the attribute and identity branches.

        Args:
            h_in: Feature map, expected as ``(B, c_x, H, W)``.
            z_attr: Attribute map with ``attr_c`` channels; after the 1x1
                convs it must broadcast against ``h_in``.
            z_id: Identity embedding whose last dimension is ``c_id`` and
                which holds ``B * c_id`` values in total, e.g. ``(B, c_id)``.

        Returns:
            Tensor with the same shape as the normalised ``h_in``.
        """
        h = self.norm(h_in)

        # Attribute branch: spatially varying scale and shift.
        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)
        attr_out = gamma_attr * h + beta_attr

        # Identity branch: one scale/shift per channel, expanded over H and W.
        batch = h.shape[0]
        gamma_id = self.fc1(z_id).reshape(batch, self.c_x, 1, 1).expand_as(h)
        beta_id = self.fc2(z_id).reshape(batch, self.c_x, 1, 1).expand_as(h)
        id_out = gamma_id * h + beta_id

        # mask -> 1 selects the identity branch, mask -> 0 the attribute branch.
        # A scalar ``1`` broadcasts, so no ones-tensor or device move is needed.
        mask = torch.sigmoid(self.conv_h(h))
        return (1 - mask) * attr_out + mask * id_out
class AddBlocksSequential(nn.Sequential):
    """Sequential container that feeds ``z_attr`` and ``z_id`` to every third module.

    The module at index 0 receives all three inputs. Every later module whose
    index is a multiple of three receives the running value together with the
    original ``z_attr`` and ``z_id``; all other modules receive the running
    value alone (unpacked if it happens to be a plain tuple).
    """

    def forward(self, *inputs):
        """Run the chain on ``(h, z_attr, z_id)`` and return the last output.

        Exactly three positional inputs are required; with no child modules
        the input tuple is returned unchanged.
        """
        # Unpacking also enforces that exactly three inputs were supplied.
        _h, z_attr, z_id = inputs

        current = inputs
        for index, layer in enumerate(self._modules.values()):
            # Re-attach the conditioning inputs at the start of each later group.
            if index > 0 and index % 3 == 0:
                current = (current, z_attr, z_id)

            # Only exact tuples are unpacked into positional arguments.
            if type(current) is tuple:
                current = layer(*current)
            else:
                current = layer(current)
        return current
class AAD_ResBlk(nn.Module):
    """Residual block built from AAD layer / ReLU / 3x3 convolution groups.

    The main path stacks ``num_blocks`` groups; every convolution keeps
    ``cin`` channels except the last one, which maps to ``cout``. When
    ``cin != cout`` a second single-group path (``last_add_block``) maps the
    block input to ``cout`` channels so it can be added to the main path;
    otherwise the block input itself is added.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        c_attr: Channel count of the attribute map passed to each AAD layer.
        c_id: Identity embedding size passed to each AAD layer.
        num_blocks: Number of AAD/ReLU/conv groups on the main path.
    """

    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
        super().__init__()
        self.cin = cin
        self.cout = cout

        layers = []
        for idx in range(num_blocks):
            # Only the final convolution changes the channel count.
            conv_out = cout if idx == num_blocks - 1 else cin
            layers.append(AADLayer(cin, c_attr, c_id))
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Conv2d(cin, conv_out, kernel_size=3, stride=1, padding=1, bias=False)
            )
        self.add_blocks = AddBlocksSequential(*layers)

        # Projection path, needed only when the skip connection must change channels.
        if cin != cout:
            self.last_add_block = AddBlocksSequential(
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            )

    def forward(self, h, z_attr, z_id):
        """Return the main-path output plus the (possibly projected) input."""
        main = self.add_blocks(h, z_attr, z_id)
        if self.cin == self.cout:
            shortcut = h
        else:
            shortcut = self.last_add_block(h, z_attr, z_id)
        return main + shortcut