# NOTE: "Spaces: Paused / Paused" was page-header text captured along with this file; it is not part of the module.
import torch
import torch.nn as nn
class AADLayer(nn.Module):
    """Blend attribute- and identity-conditioned modulations of a feature map.

    The input activation is instance-normalised and then modulated twice:

    * per pixel, with scale/shift maps produced by 1x1 convolutions of
      ``z_attr``;
    * per channel, with scale/shift vectors produced by linear projections
      of ``z_id``.

    A single-channel sigmoid mask computed from the normalised activation
    mixes the two results.

    Args:
        c_x: Channel count of the activation ``h_in`` (and of the output).
        attr_c: Channel count of ``z_attr``.
        c_id: Feature size (last dimension) of ``z_id``.
    """

    def __init__(self, c_x, attr_c, c_id):
        super().__init__()

        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        # Submodule names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)

        # Projects the normalised activation to a one-channel blending mask.
        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        """Return the mask-weighted mix of the two modulated activations.

        Args:
            h_in: Activation of shape ``(B, c_x, H, W)``.
            z_attr: Attribute features of shape ``(B, attr_c, H', W')`` whose
                spatial size must broadcast against ``h_in``.
            z_id: Identity features whose last dimension is ``c_id`` and that
                hold ``B * c_id`` elements in total, e.g. ``(B, c_id)``.

        Returns:
            Tensor with the same shape as the normalised ``h_in``.
        """
        h = self.norm(h_in)
        batch = h.shape[0]

        # Attribute branch: spatially varying scale and shift.
        attr_branch = self.conv1(z_attr) * h + self.conv2(z_attr)

        # Identity branch: one scale/shift per channel; the trailing 1x1
        # dimensions let broadcasting spread them over every pixel.
        gamma_id = self.fc1(z_id).reshape(batch, self.c_x, 1, 1)
        beta_id = self.fc2(z_id).reshape(batch, self.c_x, 1, 1)
        id_branch = gamma_id * h + beta_id

        # mask -> 1 selects the identity branch, mask -> 0 the attribute branch.
        mask = torch.sigmoid(self.conv_h(h))
        return (1 - mask) * attr_branch + mask * id_branch
class AddBlocksSequential(nn.Sequential):
    """Sequential container that re-supplies ``z_attr``/``z_id`` every third module.

    The container is called with exactly three inputs ``(h, z_attr, z_id)``.
    Modules are run in order; the module at index 0 receives all three
    inputs, and before every later module whose index is a multiple of three
    the running value is re-bundled with the original ``z_attr`` and ``z_id``.
    Whenever the running value is a tuple it is unpacked into positional
    arguments; otherwise it is passed as a single argument.
    """

    # Modules come in groups of this size; the first module of each group is
    # the one that receives (value, z_attr, z_id).
    _GROUP_SIZE = 3

    def forward(self, *inputs):
        """Run the modules in order and return the last module's output.

        Args:
            *inputs: Exactly ``(h, z_attr, z_id)``.

        Returns:
            Output of the final module, or ``inputs`` unchanged when the
            container holds no modules.

        Raises:
            ValueError: If the number of inputs is not three (raised by the
                unpacking below).
        """
        _, z_attr, z_id = inputs

        value = inputs
        for index, module in enumerate(self):
            if index > 0 and index % self._GROUP_SIZE == 0:
                value = (value, z_attr, z_id)
            # isinstance (not an exact type comparison) so tuple subclasses
            # such as namedtuples are unpacked the same way as plain tuples.
            if isinstance(value, tuple):
                value = module(*value)
            else:
                value = module(value)
        return value
class AAD_ResBlk(nn.Module):
    """Residual block built from AADLayer -> ReLU -> 3x3 Conv units.

    The main branch stacks ``num_blocks`` units; every unit keeps ``cin``
    channels except the last, which maps to ``cout``. The skip branch is the
    identity when ``cin == cout``; otherwise it is one extra unit mapping
    ``cin`` to ``cout`` so that the two branches can be added.

    Args:
        cin: Input channel count.
        cout: Output channel count.
        c_attr: Channel count of ``z_attr`` (forwarded to ``AADLayer``).
        c_id: Feature size of ``z_id`` (forwarded to ``AADLayer``).
        num_blocks: Number of units in the main branch; must be at least 1.

    Raises:
        ValueError: If ``num_blocks`` is less than 1. An empty main branch
            would hand the raw input tuple back from ``add_blocks`` and fail
            later inside ``forward`` with an unrelated error.
    """

    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        self.cin = cin
        self.cout = cout

        units = []
        for i in range(num_blocks):
            # Only the final unit changes the channel count.
            out = cin if i < num_blocks - 1 else cout
            units.extend(self._make_unit(cin, out, c_attr, c_id))
        self.add_blocks = AddBlocksSequential(*units)

        # A learned skip branch is only needed when the channel count changes.
        if cin != cout:
            self.last_add_block = AddBlocksSequential(
                *self._make_unit(cin, cout, c_attr, c_id)
            )

    @staticmethod
    def _make_unit(cin, cout, c_attr, c_id):
        """Return one [AADLayer, ReLU, 3x3 Conv] triple mapping cin -> cout."""
        return [
            AADLayer(cin, c_attr, c_id),
            nn.ReLU(inplace=True),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
        ]

    def forward(self, h, z_attr, z_id):
        """Return main-branch output plus the (possibly projected) skip branch.

        Args:
            h: Input activation with ``cin`` channels.
            z_attr: Attribute features passed to every ``AADLayer``.
            z_id: Identity features passed to every ``AADLayer``.

        Returns:
            Sum of the main branch and the skip branch (``cout`` channels).
        """
        x = self.add_blocks(h, z_attr, z_id)
        shortcut = h
        if self.cin != self.cout:
            shortcut = self.last_add_block(h, z_attr, z_id)
        return x + shortcut