# climategan painter: SPADE-based decoder ("painter") and its factory.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import climategan.strings as strings | |
| from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock | |
| from climategan.norms import SpectralNorm | |
def create_painter(opts, no_init=False, verbose=0):
    """Build the painter network from the run options.

    Args:
        opts (addict.Dict): full options dict; the painter reads ``opts.gen.p``
        no_init (bool, optional): kept for interface compatibility with other
            network factories; not used by this painter. Defaults to False.
        verbose (int, optional): when > 0, print a construction message.
            Defaults to 0.

    Returns:
        PainterSpadeDecoder: the painter module
    """
    if verbose > 0:
        print(" - Add PainterSpadeDecoder Painter")
    return PainterSpadeDecoder(opts)
class PainterSpadeDecoder(nn.Module):
    """SPADE-based decoder which forwards z conditioned on a tensor `cond`.

    In the original SPADE paper the conditioning is a semantic map; here it is
    a 3-channel image (``cond_nc = 3``). The first 3 SPADEResnetBlocks (SRB)
    keep the channel dimension, with an upsampling after the first two.
    Then, for each remaining upsampling (w.r.t. ``spade_n_up``), an SRB halves
    the channels. Before the final conv to 3 channels, the channel count is
    therefore ``final_nc = latent_dim // 2 ** (spade_n_up - 2)``.
    """

    def __init__(self, opts):
        """Create the SPADE decoder from run options.

        Args:
            opts (addict.Dict): options dict; fields read from ``opts.gen.p``:
                latent_dim (int): number of channels of the latent z
                spade_n_up (int): total number of x2 upsamplings from z
                spade_use_spectral_norm (bool): use spectral normalization
                    inside the SPADE blocks?
                spade_param_free_norm (str): norm to use before SPADE
                    de-normalization (e.g. "instance", "batch")
                use_final_shortcut (bool): if True, the conditioning of the
                    final SPADE block is produced from the features themselves
                    instead of the input image
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        cond_nc = 3  # conditioning tensor is a 3-channel image
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up

        # Spatial size of z; set lazily via set_latent_shape() when z is
        # inferred from the conditioning image in forward().
        self.z_h = self.z_w = None

        # Projects the (downsampled) 3-channel cond image to latent_dim
        # channels when no z is provided to forward().
        self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)
        self.head_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.G_middle_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.G_middle_1 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        # Each remaining upsampling halves the number of channels.
        self.up_spades = nn.Sequential(
            *[
                SPADEResnetBlock(
                    self.z_nc // 2 ** i,
                    self.z_nc // 2 ** (i + 1),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                )
                for i in range(spade_n_up - 2)
            ]
        )

        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)

        self.final_spade = SPADEResnetBlock(
            self.final_nc,
            self.final_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            self.final_shortcut = nn.Sequential(
                *[
                    SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
                    nn.BatchNorm2d(3),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)

        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Set the latent spatial shape to start the upsampling from (z_h, z_w).

        If ``is_input`` is True, ``shape`` is the shape of the network's input
        image and is divided by ``2 ** spade_n_up``; otherwise z_h and z_w are
        taken directly from shape[-2] and shape[-1].

        Args:
            shape (tuple | list | int): shape to start sampling from; an int
                means a square (shape, shape)
            is_input (bool, optional): whether to divide shape by
                ``2 ** spade_n_up``. Defaults to True.

        Raises:
            ValueError: if ``shape`` is neither a sequence nor an int
        """
        if isinstance(shape, (list, tuple)):
            self.z_h = shape[-2]
            self.z_w = shape[-1]
        elif isinstance(shape, int):
            self.z_h = self.z_w = shape
        else:
            raise ValueError(f"Unknown shape type: {shape}")

        if is_input:
            self.z_h = self.z_h // (2 ** self.spade_n_up)
            self.z_w = self.z_w // (2 ** self.spade_n_up)

    def _apply(self, fn):
        # Kept as an explicit override point for device/dtype moves; defers
        # entirely to nn.Module._apply and returns self like the default.
        super()._apply(fn)
        return self

    def forward(self, z, cond):
        """Decode z (or the downsampled cond image if z is None) into an RGB
        image in [-1, 1], conditioning every SPADE block on ``cond``.

        Args:
            z (torch.Tensor | None): latent tensor; when None,
                set_latent_shape() must have been called first and z is
                produced from a resized ``cond``
            cond (torch.Tensor): 3-channel conditioning tensor

        Returns:
            torch.Tensor: 3-channel output in [-1, 1] (tanh)
        """
        if z is None:
            assert self.z_h is not None and self.z_w is not None
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
        y = self.head_0(z, cond)
        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        for i, up in enumerate(self.up_spades):
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # Condition the last SPADE block on the features themselves.
            cond = self.final_shortcut(y)
        y = self.final_spade(y, cond)
        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        return strings.spadedecoder(self)