import torch
import torch.nn as nn
class AADLayer(nn.Module):
    """Adaptive Attentional Denormalization (AAD) layer.

    Blends identity and attribute information into a normalized activation
    map through a learned per-pixel attention mask.
    """

    def __init__(self, c_x, attr_c, c_id):
        super(AADLayer, self).__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # 1x1 convs derive per-pixel modulation parameters from the attribute map.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        # Fully connected layers derive per-channel modulation from the identity vector.
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)
        # 1x1 conv that predicts the single-channel attention mask M.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        # h_in:   (B, c_x, H, W) activation from the previous layer
        # z_attr: (B, attr_c, H, W) attribute feature map
        # z_id:   (B, c_id) identity embedding, typically 256-dim
        h = self.norm(h_in)

        # Attribute branch: spatially varying denormalization.
        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)
        A = gamma_attr * h + beta_attr

        # Identity branch: per-channel denormalization, broadcast over H x W.
        gamma_id = self.fc1(z_id).reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        beta_id = self.fc2(z_id).reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        I = gamma_id * h + beta_id

        # Per-pixel mask M in (0, 1) selects between the two branches:
        # out = (1 - M) * A + M * I.
        M = torch.sigmoid(self.conv_h(h))
        return (1 - M) * A + M * I
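
# A minimal shape sketch for AADLayer (the sizes below are illustrative
# placeholders, not values from any particular configuration):
#
#   layer = AADLayer(c_x=64, attr_c=128, c_id=256)
#   out = layer(torch.randn(2, 64, 32, 32),    # h_in
#               torch.randn(2, 128, 32, 32),   # z_attr
#               torch.randn(2, 256))           # z_id
#   # out.shape == torch.Size([2, 64, 32, 32])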
class AddBlocksSequential(nn.Sequential):
    """Sequential container that re-injects (z_attr, z_id) into every AADLayer.

    Modules are assumed to repeat in groups of three (AADLayer, ReLU, Conv2d);
    only the AADLayer at the start of each group takes the conditioning inputs.
    """

    def forward(self, *inputs):
        h, z_attr, z_id = inputs
        for i, module in enumerate(self._modules.values()):
            if i % 3 == 0 and i > 0:
                # `inputs` is the tensor produced by the previous Conv2d here;
                # repack it with the conditioning inputs for the next AADLayer.
                inputs = (inputs, z_attr, z_id)
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
class AAD_ResBlk(nn.Module):
    """Residual block built from (AADLayer -> ReLU -> Conv2d) groups.

    When cin != cout, the skip connection is itself an AAD-conditioned
    projection so that the channel counts match before the addition.
    """

    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
        super(AAD_ResBlk, self).__init__()
        self.cin = cin
        self.cout = cout

        add_blocks = []
        for i in range(num_blocks):
            # Only the last conv changes the channel count from cin to cout.
            out = cin if i < (num_blocks - 1) else cout
            add_blocks.extend([
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, out, kernel_size=3, stride=1, padding=1, bias=False),
            ])
        self.add_blocks = AddBlocksSequential(*add_blocks)

        if cin != cout:
            last_add_block = [
                AADLayer(cin, c_attr, c_id),
                nn.ReLU(inplace=True),
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            ]
            self.last_add_block = AddBlocksSequential(*last_add_block)

    def forward(self, h, z_attr, z_id):
        x = self.add_blocks(h, z_attr, z_id)
        if self.cin != self.cout:
            # Project the skip path to cout channels before the residual add.
            h = self.last_add_block(h, z_attr, z_id)
        x = x + h
        return x
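
# A minimal smoke test, assuming the shape conventions documented above;
# the batch size, spatial size, and channel counts are illustrative
# placeholders, not values from any particular training setup.
if __name__ == "__main__":
    B, H, W = 2, 32, 32
    c_attr, c_id = 128, 256

    blk = AAD_ResBlk(cin=64, cout=32, c_attr=c_attr, c_id=c_id, num_blocks=2)
    h = torch.randn(B, 64, H, W)           # activations from the previous block
    z_attr = torch.randn(B, c_attr, H, W)  # attribute feature map
    z_id = torch.randn(B, c_id)            # identity embedding
    out = blk(h, z_attr, z_id)
    print(out.shape)  # expected: torch.Size([2, 32, 32, 32])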