import torch
import numpy as np
from dp2.layers import Sequential
from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
from .base import BaseStyleGAN
from typing import List, Tuple
from .utils import spatial_embed_keypoints, mask_output

def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
    n = int(np.log2(max_imsize) - np.log2(imsize))
    mul = min(2**n, max_cnum_mul)
    ch = cnum * mul
    return int(ch)
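
# Worked example (hypothetical numbers, for illustration only): with cnum=64,
# max_imsize=256 and max_cnum_mul=8, the channel count doubles each time the
# resolution halves, until capped at cnum * max_cnum_mul:
#   get_chsize(256, 64, 256, 8) == 64
#   get_chsize(64, 64, 256, 8) == 256
#   get_chsize(8, 64, 256, 8) == 512   (capped)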

class StyleGANUnet(BaseStyleGAN):
    def __init__(
            self,
            scale_grad: bool,
            im_channels: int,
            min_fmap_resolution: int,
            imsize: List[int],
            cnum: int,
            max_cnum_mul: int,
            mask_output: bool,
            conv_clamp: int,
            input_cse: bool,
            cse_nc: int,
            n_middle_blocks: int,
            input_keypoints: bool,
            n_keypoints: int,
            input_keypoint_indices: Tuple[int],
            fix_errors: bool,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.n_keypoints = n_keypoints
        self.input_keypoint_indices = list(input_keypoint_indices)
        self.input_keypoints = input_keypoints
        assert not (input_cse and input_keypoints)
        cse_nc = 0 if cse_nc is None else cse_nc
        self.imsize = imsize
        self._cnum = cnum
        self._max_cnum_mul = max_cnum_mul
        self._min_fmap_resolution = min_fmap_resolution
        self._image_channels = im_channels
        self._max_imsize = max(imsize)
        self.input_cse = input_cse
        self.gain_unet = np.sqrt(1/3)
        n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution)) + 1
        encoder_layers = []
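        # Input layout (assembled in forward_enc below): condition image
        # (im_channels) + 1 mask channel, optionally followed by the CSE
        # embedding and its mask (cse_nc + 1), or by one one-hot map per
        # selected keypoint.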
        self.from_rgb = Conv2d(
            im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
            cnum, 1
        )
        for i in range(n_levels):  # Encoder layers
            resolution = [x//2**i for x in imsize]
            in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
            second_ch = in_ch
            out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
            down = 2
            # Downsampling is performed at the start of a block, so the first
            # (full-resolution) block skips it.
            if i == 0:
                down = 1
            if i == n_levels - 1:
                out_ch = second_ch
            block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
            encoder_layers.append(block)
        self._encoder_out_shape = [
            get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
            *resolution]
        self.encoder = torch.nn.ModuleList(encoder_layers)
        # Initialize decoder
        decoder_layers = []
        for i in range(n_levels):
            resolution = [x//2**(n_levels-1-i) for x in imsize]
            in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
            out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
            if i == 0:  # First (lowest-resolution) block takes the encoder output directly.
                in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
            up = 1
            if i != n_levels - 1:
                up = 2
            block = ResidualBlock(
                in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
                w_dim=self.style_net.w_dim, norm=True, up=up,
                fix_residual=fix_errors
            )
            decoder_layers.append(block)
            if i != 0:
                # 1x1 convolution applied to the U-Net skip connection before it
                # is added to the decoder feature map (see forward_dec).
                unet_block = Conv2d(
                    in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
                    gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
                setattr(self, f"unet_block{i}", unet_block)
        # Initialize "middle blocks" that neither downsample nor upsample
        middle_blocks = []
        for i in range(n_middle_blocks):
            ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
            block = ResidualBlock(
                ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
                w_dim=self.style_net.w_dim, norm=True,
            )
            middle_blocks.append(block)
        if n_middle_blocks != 0:
            self.middle_blocks = Sequential(*middle_blocks)
        self.decoder = torch.nn.ModuleList(decoder_layers)
        self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
        self.scale_grad = scale_grad
        self.mask_output = mask_output
    def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
        for i, layer in enumerate(self.decoder):
            if i != 0:
                # Fuse the encoder feature of matching resolution; unet_features
                # is ordered high-to-low resolution, so index from the end.
                unet_layer = getattr(self, f"unet_block{i}")
                x = x + unet_layer(unet_features[-i])
            x = layer(x, w=w, s=s)
        x = self.to_rgb(x)
        if self.mask_output:
            x = mask_output(True, condition, x, mask)
        return dict(img=x)
    def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
        if self.input_cse:
            x = torch.cat((condition, mask, embedding, E_mask), dim=1)
        else:
            x = torch.cat((condition, mask), dim=1)
        if self.input_keypoints:
            keypoints = keypoints[:, self.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        unet_features = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i != len(self.encoder)-1:
                unet_features.append(x)
        if hasattr(self, "middle_blocks"):
            for layer in self.middle_blocks:
                x = layer(x)
        return x, unet_features
    def forward(
            self, condition, mask,
            z=None, embedding=None, w=None, update_emas=False, x=None,
            s=None,
            keypoints=None,
            unet_features=None,
            E_mask=None,
            **kwargs):
        # Skip the encoder when its outputs are given (inference-time use,
        # e.g. for w projection).
        if x is not None and unet_features is not None:
            assert not self.training
        else:
            x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
        if w is None:
            if z is None:
                z = self.get_z(condition)
            w = self.get_w(z, update_emas=update_emas)
        return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)

class ComodStyleUNet(StyleGANUnet):
    def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
        super().__init__(**kwargs)
        min_fmap = min(self._encoder_out_shape[1:])
        enc_out_ch = self._encoder_out_shape[0]
        n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
        comod_layers = []
        in_ch = enc_out_ch
        for i in range(n_down):
            # Downsample towards min_comod_res while projecting to 256 channels.
            comod_layers.append(Conv2d(in_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
            in_ch = 256
        if n_down == 0:
            comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
        comod_layers.append(torch.nn.Flatten())
        out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
        in_ch_fc = np.prod(out_res) * 256
        comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
        self.comod_block = Sequential(*comod_layers)
        self.comod_fc = FullyConnectedLayer(
            512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)
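
    # CoModGAN-style co-modulation (as the class name suggests): comod_block
    # pools the encoder output into a global 512-d vector, and comod_fc fuses
    # it with the mapped latent w into the vector passed as w to every
    # decoder block below.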
    def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
        y = self.comod_block(x)
        y = torch.cat((y, w), dim=1)
        y = self.comod_fc(y)
        for i, layer in enumerate(self.decoder):
            if i != 0:
                unet_layer = getattr(self, f"unet_block{i}")
                x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
            x = layer(x, w=y)
        x = self.to_rgb(x)
        if self.mask_output:
            x = mask_output(True, condition, x, mask)
        return dict(img=x)
    def get_comod_y(self, batch, w):
        x, unet_features = self.forward_enc(**batch)
        y = self.comod_block(x)
        y = torch.cat((y, w), dim=1)
        y = self.comod_fc(y)
        return y
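
if __name__ == "__main__":
    # Hypothetical sanity check, not part of the original module: print the
    # encoder resolution/channel schedule for an assumed 256x256 configuration
    # (cnum=64, max_cnum_mul=8, min_fmap_resolution=8). Real values come from
    # the dp2 configs; this only illustrates get_chsize and the level count.
    imsize = [256, 256]
    min_fmap_resolution, cnum, max_cnum_mul = 8, 64, 8
    n_levels = int(np.log2(max(imsize)) - np.log2(min_fmap_resolution)) + 1  # 6 levels
    for i in range(n_levels):
        resolution = [x // 2**i for x in imsize]
        ch = get_chsize(max(resolution), cnum, max(imsize), max_cnum_mul)
        print(f"level {i}: {resolution} -> {ch} channels")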