# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Superresolution network architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""

import numpy as np
import torch

from nsr.networks_stylegan2 import (Conv2dLayer, SynthesisBlock,
                                    SynthesisLayer, ToRGBLayer)
from utils.torch_utils import misc, persistence
from utils.torch_utils.ops import upfirdn2d

from pdb import set_trace as st


class SynthesisBlockNoUp(torch.nn.Module):
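    """StyleGAN2 synthesis block without the 2x upsampling: the convolutions
    and the ToRGB path operate at the block's own resolution, so the spatial
    size of the input is preserved (compare SynthesisBlock in
    nsr.networks_stylegan2)."""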
    def __init__(
            self,
            in_channels,  # Number of input channels, 0 = first block.
            out_channels,  # Number of output channels.
            w_dim,  # Intermediate latent (W) dimensionality.
            resolution,  # Resolution of this block.
            img_channels,  # Number of output color channels.
            is_last,  # Is this the last block?
            architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=256,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            use_fp16=False,  # Use FP16 for this block?
            fp16_channels_last=False,  # Use channels-last memory format with FP16?
            fused_modconv_default=True,  # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
            **layer_kwargs,  # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels,
                                        out_channels,
                                        w_dim=w_dim,
                                        resolution=resolution,
                                        conv_clamp=conv_clamp,
                                        channels_last=self.channels_last,
                                        **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels,
                                    out_channels,
                                    w_dim=w_dim,
                                    resolution=resolution,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last,
                                    **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels,
                                    img_channels,
                                    w_dim=w_dim,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels,
                                    out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    up=2,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)

    def forward(self,
                x,
                img,
                ws,
                force_fp32=False,
                fused_modconv=None,
                update_emas=False,
                **layer_kwargs):
        _ = update_emas  # unused
        misc.assert_shape(ws,
                          [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(
                x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           gain=np.sqrt(0.5),
                           **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)

        # ToRGB.
        # if img is not None:
        #     misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        #     img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32,
                     memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        # assert x.dtype == dtype  # support AMP in this library
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'


#----------------------------------------------------------------------------
# for 512x512 generation


class SuperresolutionHybrid8X(torch.nn.Module):
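    """Upsamples 128x128 renderer output to 512x512 RGB using two upsampling
    SynthesisBlocks (128 -> 256 -> 512)."""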
    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 512
        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(channels,
                                     128,
                                     w_dim=512,
                                     resolution=256,
                                     img_channels=3,
                                     is_last=False,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     w_dim=512,
                                     resolution=512,
                                     img_channels=3,
                                     is_last=True,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb,
                                                  size=(self.input_resolution,
                                                        self.input_resolution),
                                                  mode='bilinear',
                                                  align_corners=False,
                                                  antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)  # block_kwargs: {'noise_mode': 'none'}
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb


#----------------------------------------------------------------------------
# for 256x256 generation


class SuperresolutionHybrid4X(torch.nn.Module):
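    """Upsamples 128x128 renderer output to 256x256 RGB: one SynthesisBlockNoUp
    at 128 followed by one upsampling SynthesisBlock at 256."""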
    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 256
        use_fp16 = sr_num_fp16_res > 0
        self.sr_antialias = sr_antialias
        self.input_resolution = 128
        self.block0 = SynthesisBlockNoUp(
            channels,
            128,
            w_dim=512,
            resolution=128,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     w_dim=512,
                                     resolution=256,
                                     img_channels=3,
                                     is_last=True,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb,
                                                  size=(self.input_resolution,
                                                        self.input_resolution),
                                                  mode='bilinear',
                                                  align_corners=False,
                                                  antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb


#----------------------------------------------------------------------------
# for 128 x 128 generation


class SuperresolutionHybrid2X(torch.nn.Module):
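    """Upsamples 64x64 renderer output to 128x128 RGB: one SynthesisBlockNoUp
    at 64 followed by one upsampling SynthesisBlock at 128."""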
    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 128
        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 64
        # self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlockNoUp(
            channels,
            128,
            w_dim=512,
            resolution=64,
            # resolution=128,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     w_dim=512,
                                     resolution=128,
                                     # resolution=256,
                                     img_channels=3,
                                     is_last=True,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb,
                                                  size=(self.input_resolution,
                                                        self.input_resolution),
                                                  mode='bilinear',
                                                  align_corners=False,
                                                  antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb


#----------------------------------------------------------------------------
# for 512x512 generation


class SuperresolutionHybrid8XDC(torch.nn.Module):
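    """Variant of SuperresolutionHybrid8X with wider blocks (256 and 128
    channels instead of 128 and 64), again upsampling 128 -> 256 -> 512."""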
    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 512
        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(channels,
                                     256,
                                     w_dim=512,
                                     resolution=256,
                                     img_channels=3,
                                     is_last=False,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)
        self.block1 = SynthesisBlock(256,
                                     128,
                                     w_dim=512,
                                     resolution=512,
                                     img_channels=3,
                                     is_last=True,
                                     use_fp16=use_fp16,
                                     conv_clamp=(256 if use_fp16 else None),
                                     **block_kwargs)

    def forward(self, rgb, x, ws, base_x=None, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)  # BS 3 512
        # st()

        if x.shape[-1] != self.input_resolution:  # resize 64 => 128
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(rgb,
                                                  size=(self.input_resolution,
                                                        self.input_resolution),
                                                  mode='bilinear',
                                                  align_corners=False,
                                                  antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        # print(f'device={self.block0.conv1.weight.device}')
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        # print(f'device={self.block1.conv1.weight.device}')
        return rgb


#----------------------------------------------------------------------------
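
# Illustrative smoke test (not part of the original module). It sketches how
# the SR modules above are typically driven: `rgb` is the low-resolution image
# from the neural renderer, `x` the rendered feature maps, and `ws` the mapped
# latents, of which only the last style vector is reused for the SR blocks.
# The batch size, `channels=32`, `ws` depth of 14, and `noise_mode='none'` are
# assumptions chosen for this sketch, not values taken from this repository's
# configs; it also assumes the bundled torch_utils ops fall back to their
# reference implementations when no GPU is available.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sr = SuperresolutionHybrid8X(channels=32,
                                 img_resolution=512,
                                 sr_num_fp16_res=0,  # keep fp32 so the sketch also runs on CPU
                                 sr_antialias=True).to(device)
    rgb = torch.randn(2, 3, 128, 128, device=device)  # low-res RGB from the renderer
    x = torch.randn(2, 32, 128, 128, device=device)   # rendered feature maps
    ws = torch.randn(2, 14, 512, device=device)       # mapped latents; only ws[:, -1] is used
    out = sr(rgb, x, ws, noise_mode='none')
    print(out.shape)  # expected: torch.Size([2, 3, 512, 512])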