Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from models.module import ResidualBlocks | |
| _DECODER_CHANNEL_DEFAULT = 512 | |
| class Decoder(nn.Module): | |
| def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1): | |
| super().__init__() | |
| self.module = nn.ModuleList() | |
| def forward(self, x): | |
| for block in self.module: | |
| x = block(x) | |
| return x | |
| class VanillaDecoder(Decoder): | |
| def __init__(self, hp, in_channels, out_channels): | |
| super().__init__(hp, in_channels, out_channels) | |
| self.depth = hp.decoder.depth | |
| self.blocks = hp.decoder.residual_blocks | |
| self.module = nn.ModuleList() | |
| if self.blocks > 0: | |
| self.module.append(ResidualBlocks(in_channels, n_blocks=self.blocks)) | |
| for layer_idx in range(1, self.depth + 1): # add upsampling layers | |
| self.module.append(nn.Sequential( | |
| nn.ConvTranspose2d(in_channels // (2 ** (layer_idx - 1)), | |
| in_channels // (2 ** layer_idx), | |
| kernel_size=3, stride=2, | |
| padding=1, output_padding=1, | |
| bias=False), | |
| nn.BatchNorm2d(in_channels // (2 ** layer_idx)), | |
| nn.ReLU(True) | |
| )) | |
| final = nn.Sequential( | |
| nn.Conv2d(in_channels // (2 ** self.depth), out_channels, kernel_size=7, padding=3, padding_mode='reflect'), | |
| nn.Tanh() | |
| ) | |
| self.module.append(final) | |