Spaces:
Runtime error
Runtime error
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
| import numpy as np | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
class NormLayer(nn.Module):
    """Normalization layer selected by name.

    Args:
        channels: number of input channels; used by the 'bn', 'in' and 'gn'
            variants.
        normalize_shape: shape handed to ``nn.LayerNorm``; only read when
            ``norm_type`` is 'layer'.
        norm_type: case-insensitive name of the normalization, one of
            'bn', 'in', 'gn', 'pixel', 'layer' or 'none'.

    Raises:
        AssertionError: if ``norm_type`` is not one of the names above.
    """

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type

        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            # The group count is hard-coded to 32.
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            # Plain functions (not lambdas) so the module stays picklable.
            self.norm = self._pixel_norm
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = self._passthrough
        else:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped when Python runs with -O; type and message are kept.
            raise AssertionError(f'Norm type {norm_type} not support.')

    @staticmethod
    def _pixel_norm(x):
        """L2-normalize ``x`` along dimension 1."""
        return F.normalize(x, p=2, dim=1)

    @staticmethod
    def _passthrough(x):
        """Return ``x * 1.0`` (a new tensor with the same values)."""
        return x * 1.0

    def forward(self, x, ref=None):
        """Apply the configured normalization to ``x``.

        ``ref`` is only forwarded to ``self.norm`` when ``norm_type`` is
        'spade'. ``__init__`` never builds such a layer, so with this
        constructor ``ref`` is ignored.
        """
        if self.norm_type == 'spade':
            return self.norm(x, ref)
        return self.norm(x)
class ReluLayer(nn.Module):
    """Activation layer selected by name.

    Args:
        channels: number of channels; only used to size ``nn.PReLU``.
        relu_type: case-insensitive name of the activation, one of
            - 'relu': ``nn.ReLU`` (in place)
            - 'leakyrelu': ``nn.LeakyReLU`` with slope 0.2 (in place)
            - 'prelu': ``nn.PReLU`` with one parameter per channel
            - 'selu': ``nn.SELU`` (in place)
            - 'none': direct pass

    Raises:
        AssertionError: if ``relu_type`` is not one of the names above.
    """

    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()

        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            # A plain function (not a lambda) so the module stays picklable.
            self.func = self._passthrough
        else:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped when Python runs with -O; type and message are kept.
            raise AssertionError(f'Relu type {relu_type} not support.')

    @staticmethod
    def _passthrough(x):
        """Return ``x * 1.0`` (a new tensor with the same values)."""
        return x * 1.0

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        return self.func(x)
class ConvLayer(nn.Module):
    """Convolution block: rescale -> reflection pad -> conv -> norm -> activation.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        scale: 'down' gives the convolution stride 2; 'up' upsamples the
            input by 2 (nearest neighbour) before the convolution. Any other
            value leaves the resolution handling at stride 1 with no
            upsampling.
        norm_type: name forwarded to ``NormLayer``.
        relu_type: name forwarded to ``ReluLayer``.
        use_pad: whether to reflection-pad the input before the convolution.
        bias: whether the convolution has a bias; forced to ``False`` when
            ``norm_type`` is 'bn'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 scale='none',
                 norm_type='none',
                 relu_type='none',
                 use_pad=True,
                 bias=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type

        # Presumably dropped because batch norm's own shift makes a conv
        # bias redundant.
        if norm_type in ['bn']:
            bias = False

        stride = 2 if scale == 'down' else 1

        # Plain functions (not lambdas) so the attribute can be pickled.
        self.scale_func = self._upsample_2x if scale == 'up' else self._keep_size

        # Pad ceil((k - 1) / 2) pixels on every side; the convolution itself
        # is built without a padding argument.
        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    @staticmethod
    def _keep_size(x):
        """Return ``x`` unchanged."""
        return x

    @staticmethod
    def _upsample_2x(x):
        """Upsample ``x`` by a factor of 2 with nearest-neighbour interpolation."""
        return nn.functional.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x):
        """Run rescale, optional padding, convolution, norm and activation."""
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out
class ResidualBlock(nn.Module):
    """Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html

    The residual branch is two 3x3 ``ConvLayer`` instances; only the first
    one gets the activation. The shortcut is the identity when the block
    keeps both resolution and channel count, otherwise a 3x3 ``ConvLayer``
    built with the same ``scale`` and default norm/activation settings.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        relu_type: activation name used by the first convolution.
        norm_type: normalization name used by both convolutions.
        scale: 'down', 'up' or 'none'.

    Raises:
        KeyError: if ``scale`` is not 'down', 'up' or 'none'.
    """

    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()

        if scale == 'none' and c_in == c_out:
            # A plain function (not a lambda) so the attribute can be pickled.
            self.shortcut_func = self._identity
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        # Per-convolution scale: downsampling happens in the second conv,
        # upsampling in the first.
        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
        first_scale, second_scale = scale_config_dict[scale]

        self.conv1 = ConvLayer(c_in, c_out, 3, first_scale, norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, second_scale, norm_type=norm_type, relu_type='none')

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged."""
        return x

    def forward(self, x):
        """Return the shortcut of ``x`` plus the two-convolution residual."""
        identity = self.shortcut_func(x)
        res = self.conv1(x)
        res = self.conv2(res)
        return identity + res
class ParseNet(nn.Module):
    """Encoder-body-decoder network with an image head and a mask head.

    Args:
        in_size: input size used to derive the number of downsampling steps.
        out_size: output size used to derive the number of upsampling steps.
        min_feat_size: feature size at the bottleneck; clamped to at most
            ``in_size``.
        base_ch: output channels of the first convolution.
        parsing_ch: output channels of the mask head.
        res_depth: number of residual blocks in the body.
        relu_type: activation name passed to every residual block.
        norm_type: normalization name passed to every residual block.
        ch_range: ``(min_ch, max_ch)`` bounds applied to the channel count
            of the residual blocks and of the head inputs.
    """

    def __init__(self,
                 in_size=128,
                 out_size=128,
                 min_feat_size=32,
                 base_ch=64,
                 parsing_ch=19,
                 res_depth=10,
                 relu_type='LeakyReLU',
                 norm_type='bn',
                 ch_range=(32, 256)):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        def ch_clip(ch):
            """Clamp a channel count into ``[min_ch, max_ch]``."""
            return max(min_ch, min(ch, max_ch))

        min_feat_size = min(in_size, min_feat_size)

        # One stride-2 / 2x-upsample block per factor of two between the
        # input/output size and the bottleneck size (log2 is floored).
        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        # NOTE(review): the stem outputs ``base_ch`` unclipped while the next
        # block expects ``ch_clip(base_ch)``; these only agree when base_ch
        # lies inside ch_range - confirm that is always the case.
        encoder_layers = [ConvLayer(3, base_ch, 3, 'none')]
        head_ch = base_ch
        for _ in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            encoder_layers.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        body_layers = []
        for _ in range(res_depth):
            body_layers.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))

        decoder_layers = []
        for _ in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            decoder_layers.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*encoder_layers)
        self.body = nn.Sequential(*body_layers)
        self.decoder = nn.Sequential(*decoder_layers)

        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A tuple ``(out_mask, out_img)``: the ``parsing_ch``-channel mask
            head output followed by the 3-channel image head output.
        """
        feat = self.encoder(x)
        # Skip connection around the whole residual body.
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img