import numpy as np
import torch
import torch.nn as nn
from easydict import EasyDict
from typing import List

from .base import BaseGenerator


class LatentVariableConcat(nn.Module):
    def __init__(self, conv2d_config):
        super().__init__()

    def forward(self, _inp):
        x, mask, batch = _inp
        z = batch["z"]
        x = torch.cat((x, z), dim=1)
        return (x, mask, batch)


def get_padding(kernel_size: int, dilation: int, stride: int):
    # stride is accepted for signature compatibility but does not enter
    # the formula; the result equals ceil(dilation * (kernel_size - 1) / 2).
    out = (dilation * (kernel_size - 1) - 1) / 2 + 1
    return int(np.floor(out))
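

# Example (illustrative, not part of the original file): this yields "same"
# padding for stride-1 convolutions:
#   get_padding(3, dilation=1, stride=1)  # -> 1
#   get_padding(1, dilation=1, stride=1)  # -> 0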


class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=None, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 demodulation=False, wsconv=False, gain=1,
                 *args, **kwargs):
        if padding is None:
            padding = get_padding(kernel_size, dilation, stride)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode)
        self.demodulation = demodulation
        self.wsconv = wsconv
        if self.wsconv:
            # Equalized learning rate (weight scaling): initialize N(0, 1)
            # and rescale at runtime by gain / sqrt(fan_in).
            fan_in = np.prod(self.weight.shape[1:]) / self.groups
            self.ws_scale = gain / np.sqrt(fan_in)
            nn.init.normal_(self.weight)
            if bias:
                nn.init.constant_(self.bias, val=0)
        assert not self.padding_mode == "circular", \
            "conv2d_forward does not support circular padding. Look at original pytorch code"

    def _get_weight(self):
        weight = self.weight
        if self.wsconv:
            weight = self.ws_scale * weight
        if self.demodulation:
            # StyleGAN2-style weight demodulation.
            demod = torch.rsqrt(weight.pow(2).sum([1, 2, 3]) + 1e-7)
            weight = weight * demod.view(self.out_channels, 1, 1, 1)
        return weight

    def conv2d_forward(self, x, weight, bias=True):
        bias_ = None
        if bias:
            bias_ = self.bias
        return nn.functional.conv2d(x, weight, bias_, self.stride,
                                    self.padding, self.dilation, self.groups)

    def forward(self, _inp):
        x, mask = _inp
        weight = self._get_weight()
        return self.conv2d_forward(x, weight), mask

    def __repr__(self):
        return ", ".join([
            super().__repr__(),
            f"Demodulation={self.demodulation}",
            f"Weight Scale={self.wsconv}",
            f"Bias={self.bias is not None}"
        ])
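

# Example usage (illustrative sketch, not from the original file):
#   conv = Conv2d(8, 16, kernel_size=3, wsconv=True, demodulation=True)
#   x = torch.randn(2, 8, 32, 32)
#   mask = torch.ones(2, 1, 32, 32)
#   y, mask = conv((x, mask))  # y: (2, 16, 32, 32); the mask passes through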


class LeakyReLU(nn.LeakyReLU):
    def forward(self, _inp):
        x, mask = _inp
        return super().forward(x), mask


class AvgPool2d(nn.AvgPool2d):
    def forward(self, _inp):
        x, mask, *args = _inp
        x = super().forward(x)
        mask = super().forward(mask)
        if len(args) > 0:
            return (x, mask, *args)
        return x, mask


def up(x):
    if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
        # Analytical normalization: 1x1 feature maps are passed through.
        return x
    return nn.functional.interpolate(
        x, scale_factor=2, mode="nearest")
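

# Example (illustrative): up() doubles spatial dimensions with nearest-
# neighbour interpolation, e.g. (N, C, 8, 8) -> (N, C, 16, 16); only the
# degenerate (1, C, 1, 1) normalization case is left untouched.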


class NearestUpsample(nn.Module):
    def forward(self, _inp):
        x, mask, *args = _inp
        x = up(x)
        mask = up(mask)
        if len(args) > 0:
            return (x, mask, *args)
        return x, mask


class PixelwiseNormalization(nn.Module):
    def forward(self, _inp):
        x, mask = _inp
        norm = torch.rsqrt((x**2).mean(dim=1, keepdim=True) + 1e-7)
        return x * norm, mask
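

# Example (illustrative): pixelwise feature normalization in the style of
# progressive GANs; each spatial position is rescaled to unit mean square
# over channels:
#   x * rsqrt(mean_c(x^2) + 1e-7)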


class Linear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Equalized learning rate: scale the weight at runtime by
        # 1 / sqrt(fan_in).
        fan_in = in_features
        self.wtScale = 1 / np.sqrt(fan_in)
        nn.init.normal_(self.weight)
        nn.init.constant_(self.bias, val=0)

    def _get_weight(self):
        return self.weight * self.wtScale

    def forward_linear(self, x, weight):
        return nn.functional.linear(x, weight, self.bias)

    def forward(self, x):
        return self.forward_linear(x, self._get_weight())
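

# Example usage (illustrative sketch):
#   fc = Linear(512, 128)
#   out = fc(torch.randn(4, 512))  # out: (4, 128)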


class OneHotPoseConcat(nn.Module):
    def forward(self, _inp):
        x, mask, batch = _inp
        landmarks = batch["landmarks_oh"]
        res = x.shape[-1]
        landmark = landmarks[res]
        x = torch.cat((x, landmark), dim=1)
        del batch["landmarks_oh"][res]
        return x, mask, batch


def transition_features(x_old, x_new, transition_variable):
    assert x_old.shape == x_new.shape, \
        "Old shape: {}, New: {}".format(x_old.shape, x_new.shape)
    return torch.lerp(x_old.float(), x_new.float(), transition_variable)
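

# Example (illustrative): transition_variable == 0 returns x_old, 1 returns
# x_new, and values in between blend linearly:
#   transition_features(torch.zeros(1, 4), torch.ones(1, 4), 0.25)
#   # -> tensor([[0.25, 0.25, 0.25, 0.25]])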


class TransitionBlock(nn.Module):
    def forward(self, _inp):
        x, mask, batch = _inp
        x = transition_features(
            batch["x_old"], x, batch["transition_value"])
        mask = transition_features(
            batch["mask_old"], mask, batch["transition_value"])
        del batch["x_old"]
        del batch["mask_old"]
        return x, mask, batch


class UnetSkipConnection(nn.Module):
    def __init__(self, conv2d_config: dict, in_channels: int,
                 out_channels: int, resolution: int,
                 residual: bool, enabled: bool):
        super().__init__()
        self.use_iconv = conv2d_config.conv.type == "iconv"
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._resolution = resolution
        self._enabled = enabled
        self._residual = residual
        if self.use_iconv:
            self.beta0 = torch.nn.Parameter(torch.tensor(1.))
            self.beta1 = torch.nn.Parameter(torch.tensor(1.))
        else:
            if self._residual:
                self.conv = build_base_conv(
                    conv2d_config, False, in_channels // 2,
                    out_channels, kernel_size=1, padding=0)
            else:
                self.conv = ConvAct(
                    conv2d_config, in_channels, out_channels,
                    kernel_size=1, padding=0)

    def forward(self, _inp):
        if not self._enabled:
            return _inp
        x, mask, batch = _inp
        skip_x, skip_mask = batch["unet_features"][self._resolution]
        assert x.shape == skip_x.shape, (x.shape, skip_x.shape)
        del batch["unet_features"][self._resolution]
        if self.use_iconv:
            denom = skip_mask * self.beta0.relu() + mask * self.beta1.relu() + 1e-8
            gamma = skip_mask * self.beta0.relu() / denom
            x = skip_x * gamma + (1 - gamma) * x
            mask = skip_mask * gamma + (1 - gamma) * mask
        else:
            if self._residual:
                skip_x, skip_mask = self.conv((skip_x, skip_mask))
                x = (x + skip_x) / np.sqrt(2)
                # `_probabilistic` is never set in __init__, so default to
                # False here to avoid an AttributeError.
                if getattr(self, "_probabilistic", False):
                    mask = (mask + skip_mask) / np.sqrt(2)
            else:
                x = torch.cat((x, skip_x), dim=1)
                x, mask = self.conv((x, mask))
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            self.__class__.__name__,
            f"In channels={self._in_channels}",
            f"Out channels={self._out_channels}",
            f"Residual: {self._residual}",
            f"Enabled: {self._enabled}",
            f"IConv: {self.use_iconv}"
        ])


def get_conv(ctype, post_act):
    type2conv = {
        "conv": Conv2d,
        "gconv": GatedConv
    }
    # Gated/iconv convolutions are not applied to output layers.
    if not post_act and ctype in ["gconv", "iconv"]:
        return type2conv["conv"]
    assert ctype in type2conv
    return type2conv[ctype]


def build_base_conv(
        conv2d_config, post_act: bool, *args, **kwargs) -> nn.Conv2d:
    for k, v in conv2d_config.conv.items():
        assert k not in kwargs
        kwargs[k] = v
    # Demodulation should not be used for output layers.
    demodulation = conv2d_config.normalization == "demodulation" and post_act
    kwargs["demodulation"] = demodulation
    conv = get_conv(conv2d_config.conv.type, post_act)
    return conv(*args, **kwargs)


def build_post_activation(in_channels, conv2d_config) -> List[nn.Module]:
    _layers = []
    negative_slope = conv2d_config.leaky_relu_nslope
    _layers.append(LeakyReLU(negative_slope, inplace=True))
    if conv2d_config.normalization == "pixel_wise":
        _layers.append(PixelwiseNormalization())
    return _layers


def build_avgpool(conv2d_config, kernel_size) -> nn.AvgPool2d:
    return AvgPool2d(kernel_size)


def build_convact(conv2d_config, *args, **kwargs):
    conv = build_base_conv(conv2d_config, True, *args, **kwargs)
    out_channels = conv.out_channels
    post_act = build_post_activation(out_channels, conv2d_config)
    return nn.Sequential(conv, *post_act)
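

# Example (illustrative sketch): building a conv + activation stack from a
# minimal config; the keys mirror the conv2d_config MSGGenerator defines below:
#   cfg = EasyDict(leaky_relu_nslope=.2, normalization="pixel_wise",
#                  conv=dict(type="conv", wsconv=True, gain=1))
#   layer = build_convact(cfg, in_channels=8, out_channels=16, kernel_size=3)
#   y, mask = layer((torch.randn(1, 8, 32, 32), torch.ones(1, 1, 32, 32)))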


class ConvAct(nn.Module):
    def __init__(self, conv2d_config, *args, **kwargs):
        super().__init__()
        self._conv2d_config = conv2d_config
        conv = build_base_conv(conv2d_config, True, *args, **kwargs)
        self.in_channels = conv.in_channels
        self.out_channels = conv.out_channels
        _layers = [conv]
        _layers.extend(build_post_activation(self.out_channels, conv2d_config))
        self.layers = nn.Sequential(*_layers)

    def forward(self, _inp):
        return self.layers(_inp)


class GatedConv(Conv2d):
    """Gated convolution: doubles the raw output channels, then uses one half
    to gate the other, so the produced tensor has out_channels channels
    (half of self.out_channels after the doubling below)."""

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        out_channels *= 2
        super().__init__(in_channels, out_channels, *args, **kwargs)
        assert self.out_channels % 2 == 0
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def conv2d_forward(self, x, weight, bias=True):
        x_ = super().conv2d_forward(x, weight, bias)
        x = x_[:, :self.out_channels // 2]
        y = x_[:, self.out_channels // 2:]
        x = self.lrelu(x)
        y = y.sigmoid()
        assert x.shape == y.shape, f"{x.shape}, {y.shape}"
        return x * y
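

# Example (illustrative sketch): the gate halves the doubled channel count,
# so the produced tensor matches the requested out_channels:
#   gconv = GatedConv(8, 16, kernel_size=3)
#   y, mask = gconv((torch.randn(1, 8, 32, 32), None))
#   # y: (1, 16, 32, 32), although gconv.out_channels reports 32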


class BasicBlock(nn.Module):
    def __init__(
            self, conv2d_config, resolution: int, in_channels: int,
            out_channels: List[int], residual: bool):
        super().__init__()
        assert len(out_channels) == 2
        self._resolution = resolution
        self._residual = residual
        self.out_channels = out_channels
        _layers = []
        _in_channels = in_channels
        for out_ch in out_channels:
            conv = build_base_conv(
                conv2d_config, True, _in_channels, out_ch, kernel_size=3,
                resolution=resolution)
            _layers.append(conv)
            _layers.extend(build_post_activation(_in_channels, conv2d_config))
            _in_channels = out_ch
        self.layers = nn.Sequential(*_layers)
        if self._residual:
            self.residual_conv = build_base_conv(
                conv2d_config, post_act=False, in_channels=in_channels,
                out_channels=out_channels[-1],
                kernel_size=1, padding=0)
            self.const = 1 / np.sqrt(2)

    def forward(self, _inp):
        x, mask, batch = _inp
        y = x
        mask_ = mask
        assert y.shape[-1] == self._resolution or y.shape[-1] == 1
        y, mask = self.layers((x, mask))
        if self._residual:
            residual, mask_ = self.residual_conv((x, mask_))
            y = (y + residual) * self.const
            mask = (mask + mask_) * self.const
        return y, mask, batch

    def extra_repr(self):
        return f"Residual={self._residual}, Resolution={self._resolution}"


class PoseNormalize(nn.Module):
    def forward(self, x):
        # Affine map x -> 2x - 1 (e.g. [0, 1] -> [-1, 1]).
        return x * 2 - 1


class ScalarPoseFCNN(nn.Module):
    def __init__(self, pose_size, hidden_size,
                 output_shape):
        super().__init__()
        self._hidden_size = hidden_size
        output_size = int(np.prod(output_shape))
        self.output_shape = output_shape
        self.pose_preprocessor = nn.Sequential(
            PoseNormalize(),
            Linear(pose_size, hidden_size),
            nn.LeakyReLU(.2),
            Linear(hidden_size, output_size),
            nn.LeakyReLU(.2)
        )

    def forward(self, _inp):
        x, mask, batch = _inp
        pose_info = batch["landmarks"]
        del batch["landmarks"]
        pose = self.pose_preprocessor(pose_info)
        pose = pose.view(-1, *self.output_shape)
        if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
            # Analytical normalization propagation
            pose = pose.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        x = torch.cat((x, pose), dim=1)
        return x, mask, batch

    def __repr__(self):
        return " ".join([
            self.__class__.__name__,
            f"hidden_size={self._hidden_size}",
            f"output shape={self.output_shape}"
        ])
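

# Example (illustrative sketch): a 14-dim pose vector mapped to a (16, 4, 4)
# feature map and concatenated onto x:
#   pose_fcnn = ScalarPoseFCNN(pose_size=14, hidden_size=128,
#                              output_shape=(16, 4, 4))
#   batch = {"landmarks": torch.rand(1, 14)}
#   x, mask, batch = pose_fcnn((torch.randn(1, 512, 4, 4),
#                               torch.ones(1, 1, 4, 4), batch))
#   # x: (1, 528, 4, 4)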


class Attention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # Channel multiplier
        self.in_channels = in_channels
        self.theta = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
            bias=False)
        self.phi = Conv2d(
            self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
            bias=False)
        self.g = Conv2d(
            self.in_channels, self.in_channels // 2, kernel_size=1, padding=0,
            bias=False)
        self.o = Conv2d(
            self.in_channels // 2, self.in_channels, kernel_size=1, padding=0,
            bias=False)
        # Learnable gain parameter
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, _inp):
        x, mask, batch = _inp
        # Apply convs
        theta, _ = self.theta((x, None))
        phi = nn.functional.max_pool2d(self.phi((x, None))[0], [2, 2])
        g = nn.functional.max_pool2d(self.g((x, None))[0], [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.in_channels // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = nn.functional.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o((torch.bmm(g, beta.transpose(1, 2)).view(
            -1, self.in_channels // 2, x.shape[2], x.shape[3]), None))[0]
        return self.gamma * o + x, mask, batch
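

# Example (illustrative sketch): self-attention in the SAGAN style; gamma
# starts at 0, so the block is initially an identity mapping:
#   attn = Attention(in_channels=64)
#   x, mask, batch = attn((torch.randn(1, 64, 16, 16),
#                          torch.ones(1, 1, 16, 16), {}))
#   # x: (1, 64, 16, 16)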


class MSGGenerator(BaseGenerator):
    def __init__(self):
        super().__init__(512)
        max_imsize = 128
        unet = dict(enabled=True, residual=False)
        min_fmap_resolution = 4
        model_size = 512
        image_channels = 3
        pose_size = 14
        residual = False
        conv_size = {
            4: model_size,
            8: model_size,
            16: model_size,
            32: model_size,
            64: model_size // 2,
            128: model_size // 4,
            256: model_size // 8,
            512: model_size // 16
        }
        self.removable_hooks = []
        self.rgb_convolutions = nn.ModuleDict()
        self.max_imsize = max_imsize
        self._image_channels = image_channels
        self._min_fmap_resolution = min_fmap_resolution
        self._residual = residual
        self._pose_size = pose_size
        self.current_imsize = max_imsize
        self._unet_cfg = unet
        self.concat_input_mask = True
        self.res2channels = {int(k): v for k, v in conv_size.items()}
        self.conv2d_config = EasyDict(
            pixel_normalization=True,
            leaky_relu_nslope=.2,
            normalization="pixel_wise",
            conv=dict(
                type="conv",
                wsconv=True,
                gain=1,
            )
        )
        self._init_decoder()
        self._init_encoder()

    def _init_encoder(self):
        self.encoder = nn.ModuleList()
        imsize = self.max_imsize
        self.from_rgb = build_convact(
            self.conv2d_config,
            in_channels=self._image_channels + self.concat_input_mask * 2,
            out_channels=self.res2channels[imsize],
            kernel_size=1)
        while imsize >= self._min_fmap_resolution:
            current_size = self.res2channels[imsize]
            next_size = self.res2channels[max(imsize // 2, self._min_fmap_resolution)]
            block = BasicBlock(
                self.conv2d_config, imsize, current_size,
                [current_size, next_size], self._residual)
            self.encoder.add_module(f"basic_block{imsize}", block)
            if imsize != self._min_fmap_resolution:
                self.encoder.add_module(
                    f"downsample{imsize}", AvgPool2d(2))
            imsize //= 2

    def _init_decoder(self):
        self.decoder = nn.ModuleList()
        self.decoder.add_module(
            "latent_concat", LatentVariableConcat(self.conv2d_config))
        if self._pose_size > 0:
            m = self._min_fmap_resolution
            pose_shape = (16, m, m)
            pose_fcnn = ScalarPoseFCNN(self._pose_size, 128, pose_shape)
            self.decoder.add_module("pose_fcnn", pose_fcnn)
        imsize = self._min_fmap_resolution
        self.rgb_convolutions = nn.ModuleDict()
        while imsize <= self.max_imsize:
            current_size = self.res2channels[max(imsize // 2, self._min_fmap_resolution)]
            start_size = current_size
            if imsize == self._min_fmap_resolution:
                # The first block consumes the latent (32 ch) and pose (16 ch)
                # concatenations instead of a U-net skip connection; skip
                # connections only exist from the second resolution on (this
                # also matches the `i >= 1` remapping in load_state_dict).
                start_size += 32
                if self._pose_size > 0:
                    start_size += 16
            else:
                self.decoder.add_module(f"upsample{imsize}", NearestUpsample())
                skip = UnetSkipConnection(
                    self.conv2d_config, current_size * 2, current_size, imsize,
                    **self._unet_cfg)
                self.decoder.add_module(f"skip_connection{imsize}", skip)
            next_size = self.res2channels[imsize]
            block = BasicBlock(
                self.conv2d_config, imsize, start_size, [start_size, next_size],
                residual=self._residual)
            self.decoder.add_module(f"basic_block{imsize}", block)
            to_rgb = build_base_conv(
                self.conv2d_config, False, in_channels=next_size,
                out_channels=self._image_channels, kernel_size=1)
            self.rgb_convolutions[str(imsize)] = to_rgb
            imsize *= 2
        self.norm_constant = len(self.rgb_convolutions)

    def forward_decoder(self, x, mask, batch):
        imsize_start = max(x.shape[-1] // 2, 1)
        rgb = torch.zeros(
            (x.shape[0], self._image_channels,
             imsize_start, imsize_start),
            dtype=x.dtype, device=x.device)
        mask_size = 1
        mask_out = torch.zeros(
            (x.shape[0], mask_size,
             imsize_start, imsize_start),
            dtype=x.dtype, device=x.device)
        imsize = self._min_fmap_resolution // 2
        for module in self.decoder:
            x, mask, batch = module((x, mask, batch))
            if isinstance(module, BasicBlock):
                imsize *= 2
                rgb = up(rgb)
                mask_out = up(mask_out)
                conv = self.rgb_convolutions[str(imsize)]
                rgb_, mask_ = conv((x, mask))
                assert rgb_.shape == rgb.shape, \
                    f"rgb_ {rgb_.shape}, rgb: {rgb.shape}"
                # MSG-style output: RGB predictions from every resolution are
                # upsampled and summed. Note that mask_out is only upsampled;
                # the per-resolution mask_ is not accumulated.
                rgb = rgb + rgb_
        return rgb / self.norm_constant, mask_out

    def forward_encoder(self, x, mask, batch):
        if self.concat_input_mask:
            x = torch.cat((x, mask, 1 - mask), dim=1)
        unet_features = {}
        x, mask = self.from_rgb((x, mask))
        for module in self.encoder:
            x, mask, batch = module((x, mask, batch))
            if isinstance(module, BasicBlock):
                unet_features[module._resolution] = (x, mask)
        return x, mask, unet_features

    def forward(
            self,
            condition,
            mask, keypoints=None, z=None,
            **kwargs):
        assert keypoints is not None, \
            "This generator is pose-conditioned and requires keypoints."
        keypoints = keypoints.flatten(start_dim=1).clip(-1, 1)
        if z is None:
            z = self.get_z(condition)
        z = z.view(-1, 32, 4, 4)
        batch = dict(
            landmarks=keypoints,
            z=z)
        orig_mask = mask
        x, mask, unet_features = self.forward_encoder(condition, mask, batch)
        batch = dict(
            landmarks=keypoints,
            z=z,
            unet_features=unet_features)
        x, mask = self.forward_decoder(x, mask, batch)
        # Blend: keep the known region from the condition image, fill the
        # masked-out region with generated content.
        x = condition * orig_mask + (1 - orig_mask) * x
        return dict(img=x)
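
    # Hypothetical usage sketch (assumes the relative import of BaseGenerator
    # resolves and that BaseGenerator provides get_z; shapes follow the
    # defaults above: max_imsize 128, pose_size 14, z reshaped to (N, 32, 4, 4)):
    #   g = MSGGenerator()
    #   condition = torch.randn(1, 3, 128, 128)
    #   mask = torch.ones(1, 1, 128, 128)  # 1 = known pixel, 0 = region to fill
    #   keypoints = torch.rand(1, 7, 2)    # flattened to pose_size == 14
    #   out = g(condition, mask, keypoints=keypoints, z=torch.randn(1, 512))
    #   out["img"].shape  # -> (1, 3, 128, 128)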

    def load_state_dict(self, state_dict, strict=True):
        if "parameters" in state_dict:
            state_dict = state_dict["parameters"]
        old_checkpoint = any("basic_block0" in key for key in state_dict)
        if not old_checkpoint:
            return super().load_state_dict(state_dict, strict=strict)
        # Remap keys from old index-based checkpoints (basic_block0, 1, ...)
        # to the resolution-based names used here (basic_block4, 8, ...).
        mapping = {"from_rgb.conv.layers.0.": "from_rgb.0."}
        imsize = self._min_fmap_resolution
        i = 0
        while imsize <= self.max_imsize:
            old_key = f"decoder.basic_block{i}."
            new_key = f"decoder.basic_block{imsize}."
            mapping[old_key] = new_key
            if i >= 1:
                old_key = old_key.replace("basic_block", "skip_connection")
                new_key = new_key.replace("basic_block", "skip_connection")
                mapping[old_key] = new_key
            old_key = f"encoder.basic_block{i}."
            new_key = f"encoder.basic_block{imsize}."
            mapping[old_key] = new_key
            i += 1
            imsize *= 2
        new_sd = {}
        for key, value in state_dict.items():
            old_key = key
            if "from_rgb" in key:
                new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value
                continue
            for subkey, new_subkey in mapping.items():
                if subkey in key:
                    old_key = key
                    key = key.replace(subkey, new_subkey)
                    break
            if "decoder.to_rgb" in key:
                continue
            new_sd[key] = value
        return super().load_state_dict(new_sd, strict=strict)

    def update_w(self, *args, **kwargs):
        return