import torch
import torch.nn as nn

from models.module import Conv2d, StyleAttentionBlock

_ENCODER_CHANNEL_DEFAULT = 256


class Encoder(nn.Module):
    def __init__(self, hp, in_channels=1, out_channels=_ENCODER_CHANNEL_DEFAULT):
        super().__init__()
        self.hp = hp
        self.module = nn.ModuleList()

    def forward(self, x):
        for block in self.module:
            x = block(x)
        return x


class ContentVanillaEncoder(Encoder):
    """Plain convolutional encoder for content images: a 7x7 stem followed by
    `depth` stride-2 convolutions that halve resolution and double channels
    until `out_channels` is reached."""

    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.content.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"

        self.module = nn.ModuleList()
        # Stem: 7x7 reflect-padded conv, keeps the input resolution.
        self.module.append(
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        )
        # Downsampling stages: each halves H/W and doubles the channel count.
        for layer_idx in range(1, self.depth + 1):
            self.module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )


class StyleVanillaEncoder(Encoder):
    """Convolutional encoder for style references: the same downsampling conv
    stack applied to each of the K references independently, followed by a
    StyleAttentionBlock that operates across the reference dimension."""

    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.style.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, "Output channel should be increased"

        encoder_module = []
        encoder_module.append(
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        )
        for layer_idx in range(1, self.depth + 1):  # downsample
            encoder_module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )
        self.add_module("encoder_module", nn.Sequential(*encoder_module))
        self.add_module("attention_module", StyleAttentionBlock(out_channels))

    def forward(self, x):
        # x: (B, K, H, W), K single-channel style references per sample.
        B, K, H, W = x.size()
        out = self.encoder_module(x.view(-1, 1, H, W))
        out = self.attention_module(out, B, K)
        return out
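

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). It assumes
# `models.module` provides Conv2d and StyleAttentionBlock as imported above,
# and that `hp` exposes hp.encoder.content.depth and hp.encoder.style.depth
# (the attribute names used in this file); the SimpleNamespace config, depth
# values, and tensor shapes below are hypothetical stand-ins for the project's
# real configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    hp = SimpleNamespace(
        encoder=SimpleNamespace(
            content=SimpleNamespace(depth=3),
            style=SimpleNamespace(depth=3),
        )
    )

    content_encoder = ContentVanillaEncoder(hp, in_channels=1, out_channels=256)
    style_encoder = StyleVanillaEncoder(hp, in_channels=1, out_channels=256)

    content_images = torch.randn(2, 1, 128, 128)   # (B, C=1, H, W) content images
    style_images = torch.randn(2, 3, 128, 128)     # (B, K=3, H, W) style references

    content_feat = content_encoder(content_images)  # (B, 256, 16, 16) if Conv2d follows nn.Conv2d shape rules
    style_feat = style_encoder(style_images)        # output shape depends on StyleAttentionBlock

    print(content_feat.shape, style_feat.shape)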