rename LICENSE => Time_TravelRephotography/LICENSE (100%)
rename LICENSE-NVIDIA => Time_TravelRephotography/LICENSE-NVIDIA (100%)
rename LICENSE-STYLEGAN2 => Time_TravelRephotography/LICENSE-STYLEGAN2 (100%)
rename dnnlib/__init__.py => Time_TravelRephotography/dnnlib/__init__.py (100%)
rename dnnlib/tflib/__init__.py => Time_TravelRephotography/dnnlib/tflib/__init__.py (100%)
rename dnnlib/tflib/autosummary.py => Time_TravelRephotography/dnnlib/tflib/autosummary.py (100%)
rename dnnlib/tflib/custom_ops.py => Time_TravelRephotography/dnnlib/tflib/custom_ops.py (100%)
rename dnnlib/tflib/network.py => Time_TravelRephotography/dnnlib/tflib/network.py (100%)
rename dnnlib/tflib/ops/__init__.py => Time_TravelRephotography/dnnlib/tflib/ops/__init__.py (100%)
rename dnnlib/tflib/ops/fused_bias_act.cu => Time_TravelRephotography/dnnlib/tflib/ops/fused_bias_act.cu (100%)
rename dnnlib/tflib/ops/fused_bias_act.py => Time_TravelRephotography/dnnlib/tflib/ops/fused_bias_act.py (100%)
rename dnnlib/tflib/ops/upfirdn_2d.cu => Time_TravelRephotography/dnnlib/tflib/ops/upfirdn_2d.cu (100%)
rename dnnlib/tflib/ops/upfirdn_2d.py => Time_TravelRephotography/dnnlib/tflib/ops/upfirdn_2d.py (100%)
rename dnnlib/tflib/optimizer.py => Time_TravelRephotography/dnnlib/tflib/optimizer.py (100%)
rename dnnlib/tflib/tfutil.py => Time_TravelRephotography/dnnlib/tflib/tfutil.py (100%)
rename dnnlib/util.py => Time_TravelRephotography/dnnlib/util.py (100%)
rename losses/color_transfer_loss.py => Time_TravelRephotography/losses/color_transfer_loss.py (100%)
rename losses/joint_loss.py => Time_TravelRephotography/losses/joint_loss.py (100%)
rename losses/perceptual_loss.py => Time_TravelRephotography/losses/perceptual_loss.py (100%)
rename losses/reconstruction.py => Time_TravelRephotography/losses/reconstruction.py (100%)
rename losses/regularize_noise.py => Time_TravelRephotography/losses/regularize_noise.py (100%)
rename model.py => Time_TravelRephotography/model.py (100%)
rename models/__init__.py => Time_TravelRephotography/models/__init__.py (100%)
rename models/degrade.py => Time_TravelRephotography/models/degrade.py (100%)
rename models/encoder.py => Time_TravelRephotography/models/encoder.py (100%)
rename models/gaussian_smoothing.py => Time_TravelRephotography/models/gaussian_smoothing.py (100%)
rename models/resnet.py => Time_TravelRephotography/models/resnet.py (100%)
rename models/vggface.py => Time_TravelRephotography/models/vggface.py (100%)
rename op/__init__.py => Time_TravelRephotography/op/__init__.py (100%)
rename op/fused_act.py => Time_TravelRephotography/op/fused_act.py (100%)
rename op/fused_bias_act.cpp => Time_TravelRephotography/op/fused_bias_act.cpp (100%)
rename op/fused_bias_act_kernel.cu => Time_TravelRephotography/op/fused_bias_act_kernel.cu (100%)
rename op/upfirdn2d.cpp => Time_TravelRephotography/op/upfirdn2d.cpp (100%)
rename op/upfirdn2d.py => Time_TravelRephotography/op/upfirdn2d.py (100%)
rename op/upfirdn2d_kernel.cu => Time_TravelRephotography/op/upfirdn2d_kernel.cu (100%)
rename optim/__init__.py => Time_TravelRephotography/optim/__init__.py (100%)
rename optim/radam.py => Time_TravelRephotography/optim/radam.py (100%)
rename projector.py => Time_TravelRephotography/projector.py (100%)
rename scripts/download_checkpoints.sh => Time_TravelRephotography/scripts/download_checkpoints.sh (100%)
rename scripts/install.sh => Time_TravelRephotography/scripts/install.sh (100%)
rename scripts/run.sh => Time_TravelRephotography/scripts/run.sh (100%)
rename tools/__init__.py => Time_TravelRephotography/tools/__init__.py (100%)
rename tools/data/__init__.py => Time_TravelRephotography/tools/data/__init__.py (100%)
rename tools/data/align_images.py => Time_TravelRephotography/tools/data/align_images.py (100%)
rename tools/initialize.py => Time_TravelRephotography/tools/initialize.py (100%)
rename tools/match_histogram.py => Time_TravelRephotography/tools/match_histogram.py (100%)
rename tools/match_skin_histogram.py => Time_TravelRephotography/tools/match_skin_histogram.py (100%)
rename tools/parse_face.py => Time_TravelRephotography/tools/parse_face.py (100%)
rename utils/__init__.py => Time_TravelRephotography/utils/__init__.py (100%)
rename utils/ffhq_dataset/__init__.py => Time_TravelRephotography/utils/ffhq_dataset/__init__.py (100%)
rename utils/ffhq_dataset/face_alignment.py => Time_TravelRephotography/utils/ffhq_dataset/face_alignment.py (100%)
rename utils/ffhq_dataset/landmarks_detector.py => Time_TravelRephotography/utils/ffhq_dataset/landmarks_detector.py (100%)
rename utils/misc.py => Time_TravelRephotography/utils/misc.py (100%)
rename utils/optimize.py => Time_TravelRephotography/utils/optimize.py (100%)
rename utils/projector_arguments.py => Time_TravelRephotography/utils/projector_arguments.py (100%)
rename utils/torch_helpers.py => Time_TravelRephotography/utils/torch_helpers.py (100%)
diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py
deleted file mode 100644
index a0b0f4efcbe1e3cd4199eeecb043d5afe1548307..0000000000000000000000000000000000000000
--- a/torch_utils/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) SenseTime Research. All rights reserved.
-
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-# empty
diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py
deleted file mode 100644
index fda77a69777a69bd3eda96713c29f66fe3b016b9..0000000000000000000000000000000000000000
--- a/torch_utils/custom_ops.py
+++ /dev/null
@@ -1,238 +0,0 @@
-# Copyright (c) SenseTime Research. All rights reserved.
-
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# NVIDIA CORPORATION and its licensors retain all intellectual property
-# and proprietary rights in and to this software, related documentation
-# and any modifications thereto. Any use, reproduction, disclosure or
-# distribution of this software and related documentation without an express
-# license agreement from NVIDIA CORPORATION is strictly prohibited.
-
-import os
-import glob
-import torch
-import torch.utils.cpp_extension
-import importlib
-import hashlib
-import shutil
-from pathlib import Path
-import re
-import uuid
-
-from torch.utils.file_baton import FileBaton
-
-#----------------------------------------------------------------------------
-# Global options.
-
-verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
-
-#----------------------------------------------------------------------------
-# Internal helper funcs.
-
-def _find_compiler_bindir():
-    patterns = [
-        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
-        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
-        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
-        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
-    ]
-    for pattern in patterns:
-        matches = sorted(glob.glob(pattern))
-        if len(matches):
-            return matches[-1]
-    return None
-
-def _get_mangled_gpu_name():
-    name = torch.cuda.get_device_name().lower()
-    out = []
-    for c in name:
-        if re.match('[a-z0-9_-]+', c):
-            out.append(c)
-        else:
-            out.append('-')
-    return ''.join(out)
-
-
-#----------------------------------------------------------------------------
-# Main entry point for compiling and loading C++/CUDA plugins.
- -_cached_plugins = dict() - -def get_plugin(module_name, sources, **build_kwargs): - assert verbosity in ['none', 'brief', 'full'] - - # Already cached? - if module_name in _cached_plugins: - return _cached_plugins[module_name] - - # Print status. - if verbosity == 'full': - print(f'Setting up PyTorch plugin "{module_name}"...') - elif verbosity == 'brief': - print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) - - try: # pylint: disable=too-many-nested-blocks - # Make sure we can find the necessary compiler binaries. - if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: - compiler_bindir = _find_compiler_bindir() - if compiler_bindir is None: - raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') - os.environ['PATH'] += ';' + compiler_bindir - - # Compile and load. - verbose_build = (verbosity == 'full') - - # Incremental build md5sum trickery. Copies all the input source files - # into a cached build directory under a combined md5 digest of the input - # source files. Copying is done only if the combined digest has changed. - # This keeps input file timestamps and filenames the same as in previous - # extension builds, allowing for fast incremental rebuilds. - # - # This optimization is done only in case all the source files reside in - # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR - # environment variable is set (we take this as a signal that the user - # actually cares about this.) - source_dirs_set = set(os.path.dirname(source) for source in sources) - if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): - all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) - - # Compute a combined hash digest for all source files in the same - # custom op directory (usually .cu, .cpp, .py and .h files). - hash_md5 = hashlib.md5() - for src in all_source_files: - with open(src, 'rb') as f: - hash_md5.update(f.read()) - build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access - digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) - - if not os.path.isdir(digest_build_dir): - os.makedirs(digest_build_dir, exist_ok=True) - baton = FileBaton(os.path.join(digest_build_dir, 'lock')) - if baton.try_acquire(): - try: - for src in all_source_files: - shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) - finally: - baton.release() - else: - # Someone else is copying source files under the digest dir, - # wait until done and continue. - baton.wait() - digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] - torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, - verbose=verbose_build, sources=digest_sources, **build_kwargs) - else: - torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) - module = importlib.import_module(module_name) - - except: - if verbosity == 'brief': - print('Failed!') - raise - - # Print status and add to cache. 
- if verbosity == 'full': - print(f'Done setting up PyTorch plugin "{module_name}".') - elif verbosity == 'brief': - print('Done.') - _cached_plugins[module_name] = module - return module - -#---------------------------------------------------------------------------- -def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs): - assert verbosity in ['none', 'brief', 'full'] - if headers is None: - headers = [] - if source_dir is not None: - sources = [os.path.join(source_dir, fname) for fname in sources] - headers = [os.path.join(source_dir, fname) for fname in headers] - - # Already cached? - if module_name in _cached_plugins: - return _cached_plugins[module_name] - - # Print status. - if verbosity == 'full': - print(f'Setting up PyTorch plugin "{module_name}"...') - elif verbosity == 'brief': - print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) - verbose_build = (verbosity == 'full') - - # Compile and load. - try: # pylint: disable=too-many-nested-blocks - # Make sure we can find the necessary compiler binaries. - if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: - compiler_bindir = _find_compiler_bindir() - if compiler_bindir is None: - raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') - os.environ['PATH'] += ';' + compiler_bindir - - # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either - # break the build or unnecessarily restrict what's available to nvcc. - # Unset it to let nvcc decide based on what's available on the - # machine. - os.environ['TORCH_CUDA_ARCH_LIST'] = '' - - # Incremental build md5sum trickery. Copies all the input source files - # into a cached build directory under a combined md5 digest of the input - # source files. Copying is done only if the combined digest has changed. - # This keeps input file timestamps and filenames the same as in previous - # extension builds, allowing for fast incremental rebuilds. - # - # This optimization is done only in case all the source files reside in - # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR - # environment variable is set (we take this as a signal that the user - # actually cares about this.) - # - # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work - # around the *.cu dependency bug in ninja config. - # - all_source_files = sorted(sources + headers) - all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) - if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): - - # Compute combined hash digest for all source files. - hash_md5 = hashlib.md5() - for src in all_source_files: - with open(src, 'rb') as f: - hash_md5.update(f.read()) - - # Select cached build directory name. - source_digest = hash_md5.hexdigest() - build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access - cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') - - if not os.path.isdir(cached_build_dir): - tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' - os.makedirs(tmpdir) - for src in all_source_files: - shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) - try: - os.replace(tmpdir, cached_build_dir) # atomic - except OSError: - # source directory already exists, delete tmpdir and its contents. 
- shutil.rmtree(tmpdir) - if not os.path.isdir(cached_build_dir): raise - - # Compile. - cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] - torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, - verbose=verbose_build, sources=cached_sources, **build_kwargs) - else: - torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) - - # Load. - module = importlib.import_module(module_name) - - except: - if verbosity == 'brief': - print('Failed!') - raise - - # Print status and add to cache dict. - if verbosity == 'full': - print(f'Done setting up PyTorch plugin "{module_name}".') - elif verbosity == 'brief': - print('Done.') - _cached_plugins[module_name] = module - return module \ No newline at end of file diff --git a/torch_utils/misc.py b/torch_utils/misc.py deleted file mode 100644 index cd512ab8b61ece35d81ec35f43948a843efbbce1..0000000000000000000000000000000000000000 --- a/torch_utils/misc.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import re -import contextlib -import numpy as np -import torch -import warnings -import dnnlib - -#---------------------------------------------------------------------------- -# Cached construction of constant tensors. Avoids CPU=>GPU copy when the -# same constant is used multiple times. - -_constant_cache = dict() - -def constant(value, shape=None, dtype=None, device=None, memory_format=None): - value = np.asarray(value) - if shape is not None: - shape = tuple(shape) - if dtype is None: - dtype = torch.get_default_dtype() - if device is None: - device = torch.device('cpu') - if memory_format is None: - memory_format = torch.contiguous_format - - key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) - tensor = _constant_cache.get(key, None) - if tensor is None: - tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) - if shape is not None: - tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) - tensor = tensor.contiguous(memory_format=memory_format) - _constant_cache[key] = tensor - return tensor - -#---------------------------------------------------------------------------- -# Replace NaN/Inf with specified numerical values. - -try: - nan_to_num = torch.nan_to_num # 1.8.0a0 -except AttributeError: - def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin - assert isinstance(input, torch.Tensor) - if posinf is None: - posinf = torch.finfo(input.dtype).max - if neginf is None: - neginf = torch.finfo(input.dtype).min - assert nan == 0 - return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) - -#---------------------------------------------------------------------------- -# Symbolic assert. 
- -try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access -except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 - -#---------------------------------------------------------------------------- -# Context manager to suppress known warnings in torch.jit.trace(). - -class suppress_tracer_warnings(warnings.catch_warnings): - def __enter__(self): - super().__enter__() - warnings.simplefilter('ignore', category=torch.jit.TracerWarning) - return self - -#---------------------------------------------------------------------------- -# Assert that the shape of a tensor matches the given list of integers. -# None indicates that the size of a dimension is allowed to vary. -# Performs symbolic assertion when used in torch.jit.trace(). - -def assert_shape(tensor, ref_shape): - if tensor.ndim != len(ref_shape): - raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') - for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): - if ref_size is None: - pass - elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') - elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') - elif size != ref_size: - raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') - -#---------------------------------------------------------------------------- -# Function decorator that calls torch.autograd.profiler.record_function(). - -def profiled_function(fn): - def decorator(*args, **kwargs): - with torch.autograd.profiler.record_function(fn.__name__): - return fn(*args, **kwargs) - decorator.__name__ = fn.__name__ - return decorator - -#---------------------------------------------------------------------------- -# Sampler for torch.utils.data.DataLoader that loops over the dataset -# indefinitely, shuffling items as it goes. - -class InfiniteSampler(torch.utils.data.Sampler): - def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): - assert len(dataset) > 0 - assert num_replicas > 0 - assert 0 <= rank < num_replicas - assert 0 <= window_size <= 1 - super().__init__(dataset) - self.dataset = dataset - self.rank = rank - self.num_replicas = num_replicas - self.shuffle = shuffle - self.seed = seed - self.window_size = window_size - - def __iter__(self): - order = np.arange(len(self.dataset)) - rnd = None - window = 0 - if self.shuffle: - rnd = np.random.RandomState(self.seed) - rnd.shuffle(order) - window = int(np.rint(order.size * self.window_size)) - - idx = 0 - while True: - i = idx % order.size - if idx % self.num_replicas == self.rank: - yield order[i] - if window >= 2: - j = (i - rnd.randint(window)) % order.size - order[i], order[j] = order[j], order[i] - idx += 1 - -#---------------------------------------------------------------------------- -# Utilities for operating with torch.nn.Module parameters and buffers. 
- -def params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.parameters()) + list(module.buffers()) - -def named_params_and_buffers(module): - assert isinstance(module, torch.nn.Module) - return list(module.named_parameters()) + list(module.named_buffers()) - -def copy_params_and_buffers(src_module, dst_module, require_all=False): - assert isinstance(src_module, torch.nn.Module) - assert isinstance(dst_module, torch.nn.Module) - src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} - for name, tensor in named_params_and_buffers(dst_module): - assert (name in src_tensors) or (not require_all) - if name in src_tensors: - tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) - -#---------------------------------------------------------------------------- -# Context manager for easily enabling/disabling DistributedDataParallel -# synchronization. - -@contextlib.contextmanager -def ddp_sync(module, sync): - assert isinstance(module, torch.nn.Module) - if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): - yield - else: - with module.no_sync(): - yield - -#---------------------------------------------------------------------------- -# Check DistributedDataParallel consistency across processes. - -def check_ddp_consistency(module, ignore_regex=None): - assert isinstance(module, torch.nn.Module) - for name, tensor in named_params_and_buffers(module): - fullname = type(module).__name__ + '.' + name - if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): - continue - tensor = tensor.detach() - other = tensor.clone() - torch.distributed.broadcast(tensor=other, src=0) - assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname - -#---------------------------------------------------------------------------- -# Print summary table of module hierarchy. - -def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): - assert isinstance(module, torch.nn.Module) - assert not isinstance(module, torch.jit.ScriptModule) - assert isinstance(inputs, (tuple, list)) - - # Register hooks. - entries = [] - nesting = [0] - def pre_hook(_mod, _inputs): - nesting[0] += 1 - def post_hook(mod, _inputs, outputs): - nesting[0] -= 1 - if nesting[0] <= max_nesting: - outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] - outputs = [t for t in outputs if isinstance(t, torch.Tensor)] - entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) - hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] - hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] - - # Run module. - outputs = module(*inputs) - for hook in hooks: - hook.remove() - - # Identify unique outputs, parameters, and buffers. - tensors_seen = set() - for e in entries: - e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] - e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] - e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] - tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} - - # Filter out redundant entries. - if skip_redundant: - entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] - - # Construct table. 
- rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] - rows += [['---'] * len(rows[0])] - param_total = 0 - buffer_total = 0 - submodule_names = {mod: name for name, mod in module.named_modules()} - for e in entries: - name = '' if e.mod is module else submodule_names[e.mod] - param_size = sum(t.numel() for t in e.unique_params) - buffer_size = sum(t.numel() for t in e.unique_buffers) - output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] - output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] - rows += [[ - name + (':0' if len(e.outputs) >= 2 else ''), - str(param_size) if param_size else '-', - str(buffer_size) if buffer_size else '-', - (output_shapes + ['-'])[0], - (output_dtypes + ['-'])[0], - ]] - for idx in range(1, len(e.outputs)): - rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] - param_total += param_size - buffer_total += buffer_size - rows += [['---'] * len(rows[0])] - rows += [['Total', str(param_total), str(buffer_total), '-', '-']] - - # Print table. - widths = [max(len(cell) for cell in column) for column in zip(*rows)] - print() - for row in rows: - print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) - print() - return outputs - -#---------------------------------------------------------------------------- diff --git a/torch_utils/models.py b/torch_utils/models.py deleted file mode 100644 index 762550239ba6f1e09f4887bf1b27fd421745a589..0000000000000000000000000000000000000000 --- a/torch_utils/models.py +++ /dev/null @@ -1,756 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py - -import math -import random -import functools -import operator - -import torch -from torch import nn -from torch.nn import functional as F -import torch.nn.init as init -from torch.autograd import Function - -from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d - - -class PixelNorm(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) - - -def make_kernel(k): - k = torch.tensor(k, dtype=torch.float32) - if k.ndim == 1: - k = k[None, :] * k[:, None] - k /= k.sum() - return k - - -class Upsample(nn.Module): - def __init__(self, kernel, factor=2): - super().__init__() - - self.factor = factor - kernel = make_kernel(kernel) * (factor ** 2) - self.register_buffer("kernel", kernel) - - p = kernel.shape[0] - factor - - pad0 = (p + 1) // 2 + factor - 1 - pad1 = p // 2 - - self.pad = (pad0, pad1) - - def forward(self, input): - out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) - return out - - -class Downsample(nn.Module): - def __init__(self, kernel, factor=2): - super().__init__() - - self.factor = factor - kernel = make_kernel(kernel) - self.register_buffer("kernel", kernel) - - p = kernel.shape[0] - factor - - pad0 = (p + 1) // 2 - pad1 = p // 2 - - self.pad = (pad0, pad1) - - def forward(self, input): - out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) - return out - - -class Blur(nn.Module): - def __init__(self, kernel, pad, upsample_factor=1): - super().__init__() - - kernel = make_kernel(kernel) - - if upsample_factor > 1: - kernel = kernel * (upsample_factor ** 2) - - self.register_buffer("kernel", kernel) - - self.pad = pad - - def forward(self, input): - out = upfirdn2d(input, self.kernel, pad=self.pad) - return out 
- - -class EqualConv2d(nn.Module): - def __init__( - self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True - ): - super().__init__() - - self.weight = nn.Parameter( - torch.randn(out_channel, in_channel, kernel_size, kernel_size) - ) - self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) - - self.stride = stride - self.padding = padding - - if bias: - self.bias = nn.Parameter(torch.zeros(out_channel)) - - else: - self.bias = None - - def forward(self, input): - out = F.conv2d( - input, - self.weight * self.scale, - bias=self.bias, - stride=self.stride, - padding=self.padding, - ) - return out - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," - f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" - ) - - -class EqualLinear(nn.Module): - def __init__( - self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None - ): - super().__init__() - - self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) - - if bias: - self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) - else: - self.bias = None - - self.activation = activation - - self.scale = (1 / math.sqrt(in_dim)) * lr_mul - self.lr_mul = lr_mul - - def forward(self, input): - if self.activation: - out = F.linear(input, self.weight * self.scale) - out = fused_leaky_relu(out, self.bias * self.lr_mul) - else: - out = F.linear( - input, self.weight * self.scale, bias=self.bias * self.lr_mul - ) - return out - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" - ) - - -class ScaledLeakyReLU(nn.Module): - def __init__(self, negative_slope=0.2): - super().__init__() - self.negative_slope = negative_slope - - def forward(self, input): - out = F.leaky_relu(input, negative_slope=self.negative_slope) - return out * math.sqrt(2) - - -class ModulatedConv2d(nn.Module): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - style_dim, - demodulate=True, - upsample=False, - downsample=False, - blur_kernel=[1, 3, 3, 1], - ): - super().__init__() - - self.eps = 1e-8 - self.kernel_size = kernel_size - self.in_channel = in_channel - self.out_channel = out_channel - self.upsample = upsample - self.downsample = downsample - - if upsample: - factor = 2 - p = (len(blur_kernel) - factor) - (kernel_size - 1) - pad0 = (p + 1) // 2 + factor - 1 - pad1 = p // 2 + 1 - self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) - - if downsample: - factor = 2 - p = (len(blur_kernel) - factor) + (kernel_size - 1) - pad0 = (p + 1) // 2 - pad1 = p // 2 - self.blur = Blur(blur_kernel, pad=(pad0, pad1)) - - fan_in = in_channel * kernel_size ** 2 - self.scale = 1 / math.sqrt(fan_in) - self.padding = kernel_size // 2 - self.weight = nn.Parameter( - torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) - ) - self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) - self.demodulate = demodulate - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " - f"upsample={self.upsample}, downsample={self.downsample})" - ) - - def forward(self, input, style): - batch, in_channel, height, width = input.shape - - style = self.modulation(style).view(batch, 1, in_channel, 1, 1) - weight = self.scale * self.weight * style - - if self.demodulate: - demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) - weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) - - 
weight = weight.view( - batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size - ) - - if self.upsample: - input = input.view(1, batch * in_channel, height, width) - weight = weight.view( - batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size - ) - weight = weight.transpose(1, 2).reshape( - batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size - ) - out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - out = self.blur(out) - - elif self.downsample: - input = self.blur(input) - _, _, height, width = input.shape - input = input.view(1, batch * in_channel, height, width) - out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - - else: - input = input.view(1, batch * in_channel, height, width) - out = F.conv2d(input, weight, padding=self.padding, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - - return out - - -class NoiseInjection(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.zeros(1)) - - def forward(self, image, noise=None): - if noise is None: - batch, _, height, width = image.shape - noise = image.new_empty(batch, 1, height, width).normal_() - return image + self.weight * noise - - -class ConstantInput(nn.Module): - def __init__(self, channel, size=4): - super().__init__() - self.input = nn.Parameter(torch.randn(1, channel, size, size // 2)) - - def forward(self, input): - batch = input.shape[0] - out = self.input.repeat(batch, 1, 1, 1) - return out - - -class StyledConv(nn.Module): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - style_dim, - upsample=False, - blur_kernel=[1, 3, 3, 1], - demodulate=True, - ): - super().__init__() - self.conv = ModulatedConv2d( - in_channel, - out_channel, - kernel_size, - style_dim, - upsample=upsample, - blur_kernel=blur_kernel, - demodulate=demodulate, - ) - self.noise = NoiseInjection() - self.activate = FusedLeakyReLU(out_channel) - - def forward(self, input, style, noise=None): - out = self.conv(input, style) - out = self.noise(out, noise=noise) - out = self.activate(out) - return out - - -class ToRGB(nn.Module): - def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): - super().__init__() - if upsample: - self.upsample = Upsample(blur_kernel) - - self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) - self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) - - def forward(self, input, style, skip=None): - out = self.conv(input, style) - out = out + self.bias - - if skip is not None: - skip = self.upsample(skip) - out = out + skip - - return out - - -class Generator(nn.Module): - def __init__( - self, - size, - style_dim, - n_mlp, - channel_multiplier=1, - blur_kernel=[1, 3, 3, 1], - lr_mlp=0.01, - small=False, - small_isaac=False, - ): - super().__init__() - - self.size = size - - if small and size > 64: - raise ValueError("small only works for sizes <= 64") - - self.style_dim = style_dim - layers = [PixelNorm()] - - for i in range(n_mlp): - layers.append( - EqualLinear( - style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" - ) - ) - - self.style = nn.Sequential(*layers) - - if small: - self.channels = { - 4: 64 * channel_multiplier, - 8: 64 * channel_multiplier, - 16: 64 * channel_multiplier, - 32: 64 * 
channel_multiplier, - 64: 64 * channel_multiplier, - } - elif small_isaac: - self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128} - else: - self.channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256 * channel_multiplier, - 128: 128 * channel_multiplier, - 256: 64 * channel_multiplier, - 512: 32 * channel_multiplier, - 1024: 16 * channel_multiplier, - } - - self.input = ConstantInput(self.channels[4]) - self.conv1 = StyledConv( - self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel - ) - self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) - - self.log_size = int(math.log(size, 2)) - self.num_layers = (self.log_size - 2) * 2 + 1 - - self.convs = nn.ModuleList() - self.upsamples = nn.ModuleList() - self.to_rgbs = nn.ModuleList() - self.noises = nn.Module() - - in_channel = self.channels[4] - - for layer_idx in range(self.num_layers): - res = (layer_idx + 5) // 2 - shape = [1, 1, 2 ** res, 2 ** res // 2] - self.noises.register_buffer( - "noise_{}".format(layer_idx), torch.randn(*shape) - ) - - for i in range(3, self.log_size + 1): - out_channel = self.channels[2 ** i] - - self.convs.append( - StyledConv( - in_channel, - out_channel, - 3, - style_dim, - upsample=True, - blur_kernel=blur_kernel, - ) - ) - - self.convs.append( - StyledConv( - out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel - ) - ) - - self.to_rgbs.append(ToRGB(out_channel, style_dim)) - in_channel = out_channel - - self.n_latent = self.log_size * 2 - 2 - - def make_noise(self): - device = self.input.input.device - - noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)] - - for i in range(3, self.log_size + 1): - for _ in range(2): - noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device)) - - return noises - - def mean_latent(self, n_latent): - latent_in = torch.randn( - n_latent, self.style_dim, device=self.input.input.device - ) - latent = self.style(latent_in).mean(0, keepdim=True) - - return latent - - def get_latent(self, input): - return self.style(input) - - def forward( - self, - styles, - return_latents=False, - return_features=False, - inject_index=None, - truncation=1, - truncation_latent=None, - input_is_latent=False, - noise=None, - randomize_noise=True, - real=False, - ): - if not input_is_latent: - styles = [self.style(s) for s in styles] - if noise is None: - if randomize_noise: - noise = [None] * self.num_layers - else: - noise = [ - getattr(self.noises, "noise_{}".format(i)) - for i in range(self.num_layers) - ] - - if truncation < 1: - # print('truncation_latent: ', truncation_latent.shape) - if not real: #if type(styles) == list: - style_t = [] - for style in styles: - style_t.append( - truncation_latent + truncation * (style - truncation_latent) - ) # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01) - styles = style_t - else: # styles are latent (tensor: 1,18,512), for real PTI output - truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512) - styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation)) - # print('now styles after truncation : ', styles) - #if type(styles) == list and len(styles) < 2: # this if for input as list of [(1,512)] - if not real: - if len(styles) < 2: - inject_index = self.n_latent - if styles[0].ndim < 3: - latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) - else: - latent = styles[0] - elif type(styles) == list: - if inject_index is None: - inject_index = 4 - - latent = styles[0].unsqueeze(0) - if 
latent.shape[1] == 1: - latent = latent.repeat(1, inject_index, 1) - else: - latent = latent[:, :inject_index, :] - latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) - latent = torch.cat([latent, latent2], 1) - else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output - latent = styles - - # print(f'processed latent: {latent.shape}') - - features = {} - out = self.input(latent) - features["out_0"] = out - out = self.conv1(out, latent[:, 0], noise=noise[0]) - features["conv1_0"] = out - - skip = self.to_rgb1(out, latent[:, 1]) - features["skip_0"] = skip - i = 1 - for conv1, conv2, noise1, noise2, to_rgb in zip( - self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs - ): - out = conv1(out, latent[:, i], noise=noise1) - features["conv1_{}".format(i)] = out - out = conv2(out, latent[:, i + 1], noise=noise2) - features["conv2_{}".format(i)] = out - skip = to_rgb(out, latent[:, i + 2], skip) - features["skip_{}".format(i)] = skip - - i += 2 - - image = skip - - if return_latents: - return image, latent - elif return_features: - return image, features - else: - return image, None - - -class ConvLayer(nn.Sequential): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - downsample=False, - blur_kernel=[1, 3, 3, 1], - bias=True, - activate=True, - ): - layers = [] - - if downsample: - factor = 2 - p = (len(blur_kernel) - factor) + (kernel_size - 1) - pad0 = (p + 1) // 2 - pad1 = p // 2 - - layers.append(Blur(blur_kernel, pad=(pad0, pad1))) - - stride = 2 - self.padding = 0 - - else: - stride = 1 - self.padding = kernel_size // 2 - - layers.append( - EqualConv2d( - in_channel, - out_channel, - kernel_size, - padding=self.padding, - stride=stride, - bias=bias and not activate, - ) - ) - - if activate: - if bias: - layers.append(FusedLeakyReLU(out_channel)) - else: - layers.append(ScaledLeakyReLU(0.2)) - - super().__init__(*layers) - - -class ResBlock(nn.Module): - def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): - super().__init__() - - self.conv1 = ConvLayer(in_channel, in_channel, 3) - self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) - - self.skip = ConvLayer( - in_channel, out_channel, 1, downsample=True, activate=False, bias=False - ) - - def forward(self, input): - out = self.conv1(input) - out = self.conv2(out) - - skip = self.skip(input) - out = (out + skip) / math.sqrt(2) - - return out - - -class StyleDiscriminator(nn.Module): - def __init__( - self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False - ): - super().__init__() - - if small: - channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64} - - else: - channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256 * channel_multiplier, - 128: 128 * channel_multiplier, - 256: 64 * channel_multiplier, - 512: 32 * channel_multiplier, - 1024: 16 * channel_multiplier, - } - - convs = [ConvLayer(3, channels[size], 1)] - - log_size = int(math.log(size, 2)) - in_channel = channels[size] - - for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] - - convs.append(ResBlock(in_channel, out_channel, blur_kernel)) - - in_channel = out_channel - - self.convs = nn.Sequential(*convs) - - self.stddev_group = 4 - self.stddev_feat = 1 - - self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) - self.final_linear = nn.Sequential( - EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), - EqualLinear(channels[4], 1), - ) - - - def forward(self, input): - h = input - h_list = [] - 
- for index, blocklist in enumerate(self.convs): - h = blocklist(h) - h_list.append(h) - - out = h - batch, channel, height, width = out.shape - group = min(batch, self.stddev_group) - stddev = out.view( - group, -1, self.stddev_feat, channel // self.stddev_feat, height, width - ) - stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) - stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) - stddev = stddev.repeat(group, 1, height, width) - out = torch.cat([out, stddev], 1) - - out = self.final_conv(out) - h_list.append(out) - - out = out.view(batch, -1) - out = self.final_linear(out) - - return out, h_list - - -class StyleEncoder(nn.Module): - def __init__(self, size, w_dim=512): - super().__init__() - - channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256, - 128: 128, - 256: 64, - 512: 32, - 1024: 16 - } - - self.w_dim = w_dim - log_size = int(math.log(size, 2)) - convs = [ConvLayer(3, channels[size], 1)] - - in_channel = channels[size] - for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] - convs.append(ResBlock(in_channel, out_channel)) - in_channel = out_channel - - convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False)) - - self.convs = nn.Sequential(*convs) - - def forward(self, input): - out = self.convs(input) - # return out.view(len(input), self.n_latents, self.w_dim) - reshaped = out.view(len(input), 2*self.w_dim) - return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:] - -def kaiming_init(m): - if isinstance(m, (nn.Linear, nn.Conv2d)): - init.kaiming_normal_(m.weight) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.fill_(0) - - -def normal_init(m): - if isinstance(m, (nn.Linear, nn.Conv2d)): - init.normal_(m.weight, 0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.fill_(0) \ No newline at end of file diff --git a/torch_utils/models_face.py b/torch_utils/models_face.py deleted file mode 100644 index ce3f5d2f3c41206c18a9dba973c8e5999ddf47fd..0000000000000000000000000000000000000000 --- a/torch_utils/models_face.py +++ /dev/null @@ -1,809 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. 
- -import math -import random -import functools -import operator - -import torch -from torch import nn -from torch.nn import functional as F -import torch.nn.init as init -from torch.autograd import Function - -from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d - - -class PixelNorm(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) - - -def make_kernel(k): - k = torch.tensor(k, dtype=torch.float32) - - if k.ndim == 1: - k = k[None, :] * k[:, None] - - k /= k.sum() - - return k - - -class Upsample(nn.Module): - def __init__(self, kernel, factor=2): - super().__init__() - - self.factor = factor - kernel = make_kernel(kernel) * (factor ** 2) - self.register_buffer("kernel", kernel) - - p = kernel.shape[0] - factor - - pad0 = (p + 1) // 2 + factor - 1 - pad1 = p // 2 - - self.pad = (pad0, pad1) - - def forward(self, input): - out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) - - return out - - -class Downsample(nn.Module): - def __init__(self, kernel, factor=2): - super().__init__() - - self.factor = factor - kernel = make_kernel(kernel) - self.register_buffer("kernel", kernel) - - p = kernel.shape[0] - factor - - pad0 = (p + 1) // 2 - pad1 = p // 2 - - self.pad = (pad0, pad1) - - def forward(self, input): - out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) - - return out - - -class Blur(nn.Module): - def __init__(self, kernel, pad, upsample_factor=1): - super().__init__() - - kernel = make_kernel(kernel) - - if upsample_factor > 1: - kernel = kernel * (upsample_factor ** 2) - - self.register_buffer("kernel", kernel) - - self.pad = pad - - def forward(self, input): - out = upfirdn2d(input, self.kernel, pad=self.pad) - - return out - - -class EqualConv2d(nn.Module): - def __init__( - self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True - ): - super().__init__() - - self.weight = nn.Parameter( - torch.randn(out_channel, in_channel, kernel_size, kernel_size) - ) - self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) - - self.stride = stride - self.padding = padding - - if bias: - self.bias = nn.Parameter(torch.zeros(out_channel)) - - else: - self.bias = None - - def forward(self, input): - out = F.conv2d( - input, - self.weight * self.scale, - bias=self.bias, - stride=self.stride, - padding=self.padding, - ) - - return out - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," - f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" - ) - - -class EqualLinear(nn.Module): - def __init__( - self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None - ): - super().__init__() - - self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) - - if bias: - self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) - - else: - self.bias = None - - self.activation = activation - - self.scale = (1 / math.sqrt(in_dim)) * lr_mul - self.lr_mul = lr_mul - - def forward(self, input): - if self.activation: - out = F.linear(input, self.weight * self.scale) - out = fused_leaky_relu(out, self.bias * self.lr_mul) - - else: - out = F.linear( - input, self.weight * self.scale, bias=self.bias * self.lr_mul - ) - - return out - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" - ) - - -class ScaledLeakyReLU(nn.Module): - def __init__(self, 
negative_slope=0.2): - super().__init__() - - self.negative_slope = negative_slope - - def forward(self, input): - out = F.leaky_relu(input, negative_slope=self.negative_slope) - - return out * math.sqrt(2) - - -class ModulatedConv2d(nn.Module): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - style_dim, - demodulate=True, - upsample=False, - downsample=False, - blur_kernel=[1, 3, 3, 1], - ): - super().__init__() - - self.eps = 1e-8 - self.kernel_size = kernel_size - self.in_channel = in_channel - self.out_channel = out_channel - self.upsample = upsample - self.downsample = downsample - - if upsample: - factor = 2 - p = (len(blur_kernel) - factor) - (kernel_size - 1) - pad0 = (p + 1) // 2 + factor - 1 - pad1 = p // 2 + 1 - - self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) - - if downsample: - factor = 2 - p = (len(blur_kernel) - factor) + (kernel_size - 1) - pad0 = (p + 1) // 2 - pad1 = p // 2 - - self.blur = Blur(blur_kernel, pad=(pad0, pad1)) - - fan_in = in_channel * kernel_size ** 2 - self.scale = 1 / math.sqrt(fan_in) - self.padding = kernel_size // 2 - - self.weight = nn.Parameter( - torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) - ) - - self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) - - self.demodulate = demodulate - - def __repr__(self): - return ( - f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " - f"upsample={self.upsample}, downsample={self.downsample})" - ) - - def forward(self, input, style): - batch, in_channel, height, width = input.shape - - style = self.modulation(style).view(batch, 1, in_channel, 1, 1) - weight = self.scale * self.weight * style - - if self.demodulate: - demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) - weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) - - weight = weight.view( - batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size - ) - - if self.upsample: - input = input.view(1, batch * in_channel, height, width) - weight = weight.view( - batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size - ) - weight = weight.transpose(1, 2).reshape( - batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size - ) - out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - out = self.blur(out) - - elif self.downsample: - input = self.blur(input) - _, _, height, width = input.shape - input = input.view(1, batch * in_channel, height, width) - out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - - else: - input = input.view(1, batch * in_channel, height, width) - out = F.conv2d(input, weight, padding=self.padding, groups=batch) - _, _, height, width = out.shape - out = out.view(batch, self.out_channel, height, width) - - return out - - -class NoiseInjection(nn.Module): - def __init__(self): - super().__init__() - - self.weight = nn.Parameter(torch.zeros(1)) - - def forward(self, image, noise=None): - if noise is None: - batch, _, height, width = image.shape - noise = image.new_empty(batch, 1, height, width).normal_() - - return image + self.weight * noise - - -class ConstantInput(nn.Module): - def __init__(self, channel, size=4): - super().__init__() - - self.input = nn.Parameter(torch.randn(1, channel, size, size)) - - def forward(self, input): - batch = 
input.shape[0] - out = self.input.repeat(batch, 1, 1, 1) - - return out - - -class StyledConv(nn.Module): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - style_dim, - upsample=False, - blur_kernel=[1, 3, 3, 1], - demodulate=True, - ): - super().__init__() - - self.conv = ModulatedConv2d( - in_channel, - out_channel, - kernel_size, - style_dim, - upsample=upsample, - blur_kernel=blur_kernel, - demodulate=demodulate, - ) - - self.noise = NoiseInjection() - # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) - # self.activate = ScaledLeakyReLU(0.2) - self.activate = FusedLeakyReLU(out_channel) - - def forward(self, input, style, noise=None): - out = self.conv(input, style) - out = self.noise(out, noise=noise) - # out = out + self.bias - out = self.activate(out) - - return out - - -class ToRGB(nn.Module): - def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): - super().__init__() - - if upsample: - self.upsample = Upsample(blur_kernel) - - self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) - self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) - - def forward(self, input, style, skip=None): - out = self.conv(input, style) - out = out + self.bias - - if skip is not None: - skip = self.upsample(skip) - - out = out + skip - - return out - - -class Generator(nn.Module): - def __init__( - self, - size, - style_dim, - n_mlp, - channel_multiplier=1, - blur_kernel=[1, 3, 3, 1], - lr_mlp=0.01, - small=False, - small_isaac=False, - ): - super().__init__() - - self.size = size - - if small and size > 64: - raise ValueError("small only works for sizes <= 64") - - self.style_dim = style_dim - - layers = [PixelNorm()] - - for i in range(n_mlp): - layers.append( - EqualLinear( - style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" - ) - ) - - self.style = nn.Sequential(*layers) - - if small: - self.channels = { - 4: 64 * channel_multiplier, - 8: 64 * channel_multiplier, - 16: 64 * channel_multiplier, - 32: 64 * channel_multiplier, - 64: 64 * channel_multiplier, - } - elif small_isaac: - self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128} - else: - self.channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256 * channel_multiplier, - 128: 128 * channel_multiplier, - 256: 64 * channel_multiplier, - 512: 32 * channel_multiplier, - 1024: 16 * channel_multiplier, - } - - self.input = ConstantInput(self.channels[4]) - self.conv1 = StyledConv( - self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel - ) - self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) - - self.log_size = int(math.log(size, 2)) - self.num_layers = (self.log_size - 2) * 2 + 1 - - self.convs = nn.ModuleList() - self.upsamples = nn.ModuleList() - self.to_rgbs = nn.ModuleList() - self.noises = nn.Module() - - in_channel = self.channels[4] - - for layer_idx in range(self.num_layers): - res = (layer_idx + 5) // 2 - shape = [1, 1, 2 ** res, 2 ** res] - self.noises.register_buffer( - "noise_{}".format(layer_idx), torch.randn(*shape) - ) - - for i in range(3, self.log_size + 1): - out_channel = self.channels[2 ** i] - - self.convs.append( - StyledConv( - in_channel, - out_channel, - 3, - style_dim, - upsample=True, - blur_kernel=blur_kernel, - ) - ) - - self.convs.append( - StyledConv( - out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel - ) - ) - - self.to_rgbs.append(ToRGB(out_channel, style_dim)) - - in_channel = out_channel - - self.n_latent = self.log_size * 2 - 2 - - def 
make_noise(self): - device = self.input.input.device - - noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] - - for i in range(3, self.log_size + 1): - for _ in range(2): - noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) - - return noises - - def mean_latent(self, n_latent): - latent_in = torch.randn( - n_latent, self.style_dim, device=self.input.input.device - ) - latent = self.style(latent_in).mean(0, keepdim=True) - - return latent - - def get_latent(self, input): - return self.style(input) - - def forward( - self, - styles, - return_latents=False, - return_features=False, - inject_index=None, - truncation=1, - truncation_latent=None, - input_is_latent=False, - noise=None, - randomize_noise=True, - ): - if not input_is_latent: - # print("haha") - styles = [self.style(s) for s in styles] - if noise is None: - if randomize_noise: - noise = [None] * self.num_layers - else: - noise = [ - getattr(self.noises, "noise_{}".format(i)) - for i in range(self.num_layers) - ] - - if truncation < 1: - style_t = [] - - for style in styles: - style_t.append( - truncation_latent + truncation * (style - truncation_latent) - ) - - styles = style_t - # print(styles) - if len(styles) < 2: - inject_index = self.n_latent - - if styles[0].ndim < 3: - latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) - # print("a") - else: - # print(len(styles)) - latent = styles[0] - # print("b", latent.shape) - - else: - # print("c") - if inject_index is None: - inject_index = 4 - - latent = styles[0].unsqueeze(0) - if latent.shape[1] == 1: - latent = latent.repeat(1, inject_index, 1) - else: - latent = latent[:, :inject_index, :] - latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) - - latent = torch.cat([latent, latent2], 1) - - features = {} - out = self.input(latent) - features["out_0"] = out - out = self.conv1(out, latent[:, 0], noise=noise[0]) - features["conv1_0"] = out - - skip = self.to_rgb1(out, latent[:, 1]) - features["skip_0"] = skip - i = 1 - for conv1, conv2, noise1, noise2, to_rgb in zip( - self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs - ): - out = conv1(out, latent[:, i], noise=noise1) - features["conv1_{}".format(i)] = out - out = conv2(out, latent[:, i + 1], noise=noise2) - features["conv2_{}".format(i)] = out - skip = to_rgb(out, latent[:, i + 2], skip) - features["skip_{}".format(i)] = skip - - i += 2 - - image = skip - - if return_latents: - return image, latent - elif return_features: - return image, features - else: - return image, None - - -class ConvLayer(nn.Sequential): - def __init__( - self, - in_channel, - out_channel, - kernel_size, - downsample=False, - blur_kernel=[1, 3, 3, 1], - bias=True, - activate=True, - ): - layers = [] - - if downsample: - factor = 2 - p = (len(blur_kernel) - factor) + (kernel_size - 1) - pad0 = (p + 1) // 2 - pad1 = p // 2 - - layers.append(Blur(blur_kernel, pad=(pad0, pad1))) - - stride = 2 - self.padding = 0 - - else: - stride = 1 - self.padding = kernel_size // 2 - - layers.append( - EqualConv2d( - in_channel, - out_channel, - kernel_size, - padding=self.padding, - stride=stride, - bias=bias and not activate, - ) - ) - - if activate: - if bias: - layers.append(FusedLeakyReLU(out_channel)) - - else: - layers.append(ScaledLeakyReLU(0.2)) - - super().__init__(*layers) - - -class ResBlock(nn.Module): - def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): - super().__init__() - - self.conv1 = ConvLayer(in_channel, in_channel, 3) - self.conv2 = ConvLayer(in_channel, 
out_channel, 3, downsample=True) - - self.skip = ConvLayer( - in_channel, out_channel, 1, downsample=True, activate=False, bias=False - ) - - def forward(self, input): - out = self.conv1(input) - out = self.conv2(out) - - skip = self.skip(input) - out = (out + skip) / math.sqrt(2) - - return out - - -class StyleDiscriminator(nn.Module): - def __init__( - self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False - ): - super().__init__() - - if small: - channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64} - - else: - channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256 * channel_multiplier, - 128: 128 * channel_multiplier, - 256: 64 * channel_multiplier, - 512: 32 * channel_multiplier, - 1024: 16 * channel_multiplier, - } - - convs = [ConvLayer(3, channels[size], 1)] - - log_size = int(math.log(size, 2)) - - in_channel = channels[size] - - for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] - - convs.append(ResBlock(in_channel, out_channel, blur_kernel)) - - in_channel = out_channel - - self.convs = nn.Sequential(*convs) - - self.stddev_group = 4 - self.stddev_feat = 1 - - self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) - self.final_linear = nn.Sequential( - EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), - EqualLinear(channels[4], 1), - ) - -# def forward(self, input): -# out = self.convs(input) - -# batch, channel, height, width = out.shape -# group = min(batch, self.stddev_group) -# stddev = out.view( -# group, -1, self.stddev_feat, channel // self.stddev_feat, height, width -# ) -# stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) -# stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) -# stddev = stddev.repeat(group, 1, height, width) -# out = torch.cat([out, stddev], 1) - -# out = self.final_conv(out) - -# out = out.view(batch, -1) -# out = self.final_linear(out) - -# return out - - def forward(self, input): - h = input - h_list = [] - - for index, blocklist in enumerate(self.convs): - h = blocklist(h) - h_list.append(h) - - out = h - batch, channel, height, width = out.shape - group = min(batch, self.stddev_group) - stddev = out.view( - group, -1, self.stddev_feat, channel // self.stddev_feat, height, width - ) - stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) - stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) - stddev = stddev.repeat(group, 1, height, width) - out = torch.cat([out, stddev], 1) - - out = self.final_conv(out) - h_list.append(out) - - out = out.view(batch, -1) - out = self.final_linear(out) - - return out, h_list - - -class StyleEncoder(nn.Module): - def __init__(self, size, w_dim=512): - super().__init__() - - channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256, - 128: 128, - 256: 64, - 512: 32, - 1024: 16 - } - - self.w_dim = w_dim - log_size = int(math.log(size, 2)) - - # self.n_latents = log_size*2 - 2 - - convs = [ConvLayer(3, channels[size], 1)] - - in_channel = channels[size] - for i in range(log_size, 2, -1): - out_channel = channels[2 ** (i - 1)] - convs.append(ResBlock(in_channel, out_channel)) - in_channel = out_channel - - # convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False)) - convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False)) - - - self.convs = nn.Sequential(*convs) - - def forward(self, input): - out = self.convs(input) - # return out.view(len(input), self.n_latents, self.w_dim) - reshaped = out.view(len(input), 2*self.w_dim) - return reshaped[:,:self.w_dim], 
reshaped[:,self.w_dim:] - -def kaiming_init(m): - if isinstance(m, (nn.Linear, nn.Conv2d)): - init.kaiming_normal_(m.weight) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.fill_(0) - - -def normal_init(m): - if isinstance(m, (nn.Linear, nn.Conv2d)): - init.normal_(m.weight, 0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.fill_(0) \ No newline at end of file diff --git a/torch_utils/op_edit/__init__.py b/torch_utils/op_edit/__init__.py deleted file mode 100644 index d2a7efe79d871852affd9de7b46f726a7942f218..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -from .fused_act import FusedLeakyReLU, fused_leaky_relu -from .upfirdn2d import upfirdn2d diff --git a/torch_utils/op_edit/fused_act.py b/torch_utils/op_edit/fused_act.py deleted file mode 100644 index 138f090bc67b94b363c346cbf405990f1bbdff68..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/fused_act.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -import os - -import torch -from torch import nn -from torch.nn import functional as F -from torch.autograd import Function -from torch.utils.cpp_extension import load - - -module_path = os.path.dirname(__file__) -fused = load( - "fused", - sources=[ - os.path.join(module_path, "fused_bias_act.cpp"), - os.path.join(module_path, "fused_bias_act_kernel.cu"), - ], -) - - -class FusedLeakyReLUFunctionBackward(Function): - @staticmethod - def forward(ctx, grad_output, out, negative_slope, scale): - ctx.save_for_backward(out) - ctx.negative_slope = negative_slope - ctx.scale = scale - - empty = grad_output.new_empty(0) - - grad_input = fused.fused_bias_act( - grad_output, empty, out, 3, 1, negative_slope, scale - ) - - dim = [0] - - if grad_input.ndim > 2: - dim += list(range(2, grad_input.ndim)) - - grad_bias = grad_input.sum(dim).detach() - - return grad_input, grad_bias - - @staticmethod - def backward(ctx, gradgrad_input, gradgrad_bias): - (out,) = ctx.saved_tensors - gradgrad_out = fused.fused_bias_act( - gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale - ) - - return gradgrad_out, None, None, None - - -class FusedLeakyReLUFunction(Function): - @staticmethod - def forward(ctx, input, bias, negative_slope, scale): - empty = input.new_empty(0) - out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) - ctx.save_for_backward(out) - ctx.negative_slope = negative_slope - ctx.scale = scale - - return out - - @staticmethod - def backward(ctx, grad_output): - (out,) = ctx.saved_tensors - - grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( - grad_output, out, ctx.negative_slope, ctx.scale - ) - - return grad_input, grad_bias, None, None - - -class FusedLeakyReLU(nn.Module): - def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): - super().__init__() - - self.bias = nn.Parameter(torch.zeros(channel)) - self.negative_slope = negative_slope - self.scale = scale - - def forward(self, input): - return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) - - -def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): - if input.device.type == "cpu": - rest_dim = [1] * (input.ndim - bias.ndim - 
1) - return ( - F.leaky_relu( - input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 - ) - * scale - ) - - else: - return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/torch_utils/op_edit/fused_bias_act.cpp b/torch_utils/op_edit/fused_bias_act.cpp deleted file mode 100644 index a79a3d65b8fb56393c954630ae8ce5a5c8a8bb7d..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/fused_bias_act.cpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -#include - - -torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale); - -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale) { - CHECK_CUDA(input); - CHECK_CUDA(bias); - - return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); -} \ No newline at end of file diff --git a/torch_utils/op_edit/fused_bias_act_kernel.cu b/torch_utils/op_edit/fused_bias_act_kernel.cu deleted file mode 100644 index 2d72170bfbd766d7f6ccaf9bdd866833a5dad14f..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/fused_bias_act_kernel.cu +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. -// -// This work is made available under the Nvidia Source Code License-NC. -// To view a copy of this license, visit -// https://nvlabs.github.io/stylegan2/license.html - -#include - -#include -#include -#include -#include - -#include -#include - - -template -static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, - int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { - int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; - - scalar_t zero = 0.0; - - for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { - scalar_t x = p_x[xi]; - - if (use_bias) { - x += p_b[(xi / step_b) % size_b]; - } - - scalar_t ref = use_ref ? p_ref[xi] : zero; - - scalar_t y; - - switch (act * 10 + grad) { - default: - case 10: y = x; break; - case 11: y = x; break; - case 12: y = 0.0; break; - - case 30: y = (x > 0.0) ? x : x * alpha; break; - case 31: y = (ref > 0.0) ? x : x * alpha; break; - case 32: y = 0.0; break; - } - - out[xi] = y * scale; - } -} - - -torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale) { - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - - auto x = input.contiguous(); - auto b = bias.contiguous(); - auto ref = refer.contiguous(); - - int use_bias = b.numel() ? 1 : 0; - int use_ref = ref.numel() ? 
1 : 0; - - int size_x = x.numel(); - int size_b = b.numel(); - int step_b = 1; - - for (int i = 1 + 1; i < x.dim(); i++) { - step_b *= x.size(i); - } - - int loop_x = 4; - int block_size = 4 * 32; - int grid_size = (size_x - 1) / (loop_x * block_size) + 1; - - auto y = torch::empty_like(x); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { - fused_bias_act_kernel<<>>( - y.data_ptr(), - x.data_ptr(), - b.data_ptr(), - ref.data_ptr(), - act, - grad, - alpha, - scale, - loop_x, - size_x, - step_b, - size_b, - use_bias, - use_ref - ); - }); - - return y; -} \ No newline at end of file diff --git a/torch_utils/op_edit/upfirdn2d.cpp b/torch_utils/op_edit/upfirdn2d.cpp deleted file mode 100644 index ec39812ba6f50386a0b7f5bf95545265ec419930..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/upfirdn2d.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -#include - - -torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, - int up_x, int up_y, int down_x, int down_y, - int pad_x0, int pad_x1, int pad_y0, int pad_y1); - -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, - int up_x, int up_y, int down_x, int down_y, - int pad_x0, int pad_x1, int pad_y0, int pad_y1) { - CHECK_CUDA(input); - CHECK_CUDA(kernel); - - return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); -} \ No newline at end of file diff --git a/torch_utils/op_edit/upfirdn2d.py b/torch_utils/op_edit/upfirdn2d.py deleted file mode 100644 index 874c09c5e98bee1ace64408aa31ec547dfe695a4..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/upfirdn2d.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. 
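Note: the upfirdn2d op bound above (and wrapped in Python in the file that follows) pads, upsamples, FIR-filters, and downsamples in a single pass; its output size follows out = (in * up + pad0 + pad1 - kernel) // down + 1, the same arithmetic the CUDA kernel uses for out_h/out_w. A minimal sketch of that arithmetic under the default [1, 3, 3, 1] blur kernel and the pad values computed by the Upsample/Downsample modules earlier in this patch (the helper name is illustrative, not part of the extension):

def upfirdn2d_output_size(in_size, kernel_size, up=1, down=1, pad=(0, 0)):
    # Same formula as the out_h / out_w computation in the CUDA kernel.
    pad0, pad1 = pad
    return (in_size * up + pad0 + pad1 - kernel_size) // down + 1

# 2x upsampling with a 4-tap kernel: p = 4 - 2 = 2, pad = (2, 1) as in Upsample,
# so a 16x16 feature map comes out as 32x32.
assert upfirdn2d_output_size(16, 4, up=2, pad=(2, 1)) == 32
# 2x downsampling with the same kernel: pad = (1, 1) as in Downsample,
# so 32x32 comes back down to 16x16.
assert upfirdn2d_output_size(32, 4, down=2, pad=(1, 1)) == 16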
- -import os - -import torch -from torch.nn import functional as F -from torch.autograd import Function -from torch.utils.cpp_extension import load - - -module_path = os.path.dirname(__file__) -upfirdn2d_op = load( - "upfirdn2d", - sources=[ - os.path.join(module_path, "upfirdn2d.cpp"), - os.path.join(module_path, "upfirdn2d_kernel.cu"), - ], -) - - -class UpFirDn2dBackward(Function): - @staticmethod - def forward( - ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size - ): - - up_x, up_y = up - down_x, down_y = down - g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad - - grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) - - grad_input = upfirdn2d_op.upfirdn2d( - grad_output, - grad_kernel, - down_x, - down_y, - up_x, - up_y, - g_pad_x0, - g_pad_x1, - g_pad_y0, - g_pad_y1, - ) - grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) - - ctx.save_for_backward(kernel) - - pad_x0, pad_x1, pad_y0, pad_y1 = pad - - ctx.up_x = up_x - ctx.up_y = up_y - ctx.down_x = down_x - ctx.down_y = down_y - ctx.pad_x0 = pad_x0 - ctx.pad_x1 = pad_x1 - ctx.pad_y0 = pad_y0 - ctx.pad_y1 = pad_y1 - ctx.in_size = in_size - ctx.out_size = out_size - - return grad_input - - @staticmethod - def backward(ctx, gradgrad_input): - (kernel,) = ctx.saved_tensors - - gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) - - gradgrad_out = upfirdn2d_op.upfirdn2d( - gradgrad_input, - kernel, - ctx.up_x, - ctx.up_y, - ctx.down_x, - ctx.down_y, - ctx.pad_x0, - ctx.pad_x1, - ctx.pad_y0, - ctx.pad_y1, - ) - # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) - gradgrad_out = gradgrad_out.view( - ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] - ) - - return gradgrad_out, None, None, None, None, None, None, None, None - - -class UpFirDn2d(Function): - @staticmethod - def forward(ctx, input, kernel, up, down, pad): - up_x, up_y = up - down_x, down_y = down - pad_x0, pad_x1, pad_y0, pad_y1 = pad - - kernel_h, kernel_w = kernel.shape - batch, channel, in_h, in_w = input.shape - ctx.in_size = input.shape - - input = input.reshape(-1, in_h, in_w, 1) - - ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - ctx.out_size = (out_h, out_w) - - ctx.up = (up_x, up_y) - ctx.down = (down_x, down_y) - ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) - - g_pad_x0 = kernel_w - pad_x0 - 1 - g_pad_y0 = kernel_h - pad_y0 - 1 - g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 - g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 - - ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) - - out = upfirdn2d_op.upfirdn2d( - input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 - ) - # out = out.view(major, out_h, out_w, minor) - out = out.view(-1, channel, out_h, out_w) - - return out - - @staticmethod - def backward(ctx, grad_output): - kernel, grad_kernel = ctx.saved_tensors - - grad_input = UpFirDn2dBackward.apply( - grad_output, - kernel, - grad_kernel, - ctx.up, - ctx.down, - ctx.pad, - ctx.g_pad, - ctx.in_size, - ctx.out_size, - ) - - return grad_input, None, None, None, None - - -def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): - if input.device.type == "cpu": - out = upfirdn2d_native( - input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] - ) - - else: - out = UpFirDn2d.apply( - input, kernel, (up, up), 
(down, down), (pad[0], pad[1], pad[0], pad[1]) - ) - - return out - - -def upfirdn2d_native( - input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 -): - _, channel, in_h, in_w = input.shape - input = input.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = input.shape - kernel_h, kernel_w = kernel.shape - - out = input.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] - ) - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] - ) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) diff --git a/torch_utils/op_edit/upfirdn2d_kernel.cu b/torch_utils/op_edit/upfirdn2d_kernel.cu deleted file mode 100644 index f82f113bd489e86b4e5ee6bc40c9c3d75e30aead..0000000000000000000000000000000000000000 --- a/torch_utils/op_edit/upfirdn2d_kernel.cu +++ /dev/null @@ -1,371 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. -// -// This work is made available under the Nvidia Source Code License-NC. 
-// To view a copy of this license, visit -// https://nvlabs.github.io/stylegan2/license.html - -#include - -#include -#include -#include -#include - -#include -#include - -static __host__ __device__ __forceinline__ int floor_div(int a, int b) { - int c = a / b; - - if (c * b > a) { - c--; - } - - return c; -} - -struct UpFirDn2DKernelParams { - int up_x; - int up_y; - int down_x; - int down_y; - int pad_x0; - int pad_x1; - int pad_y0; - int pad_y1; - - int major_dim; - int in_h; - int in_w; - int minor_dim; - int kernel_h; - int kernel_w; - int out_h; - int out_w; - int loop_major; - int loop_x; -}; - -template -__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, - const scalar_t *kernel, - const UpFirDn2DKernelParams p) { - int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; - int out_y = minor_idx / p.minor_dim; - minor_idx -= out_y * p.minor_dim; - int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; - int major_idx_base = blockIdx.z * p.loop_major; - - if (out_x_base >= p.out_w || out_y >= p.out_h || - major_idx_base >= p.major_dim) { - return; - } - - int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; - int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); - int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; - int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; - - for (int loop_major = 0, major_idx = major_idx_base; - loop_major < p.loop_major && major_idx < p.major_dim; - loop_major++, major_idx++) { - for (int loop_x = 0, out_x = out_x_base; - loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { - int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; - int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); - int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; - int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; - - const scalar_t *x_p = - &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + - minor_idx]; - const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; - int x_px = p.minor_dim; - int k_px = -p.up_x; - int x_py = p.in_w * p.minor_dim; - int k_py = -p.up_y * p.kernel_w; - - scalar_t v = 0.0f; - - for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) { - v += static_cast(*x_p) * static_cast(*k_p); - x_p += x_px; - k_p += k_px; - } - - x_p += x_py - w * x_px; - k_p += k_py - w * k_px; - } - - out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + - minor_idx] = v; - } - } -} - -template -__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, - const scalar_t *kernel, - const UpFirDn2DKernelParams p) { - const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; - const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; - - __shared__ volatile float sk[kernel_h][kernel_w]; - __shared__ volatile float sx[tile_in_h][tile_in_w]; - - int minor_idx = blockIdx.x; - int tile_out_y = minor_idx / p.minor_dim; - minor_idx -= tile_out_y * p.minor_dim; - tile_out_y *= tile_out_h; - int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; - int major_idx_base = blockIdx.z * p.loop_major; - - if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | - major_idx_base >= p.major_dim) { - return; - } - - for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; - tap_idx += blockDim.x) { - int ky = tap_idx / kernel_w; - int kx = tap_idx - ky * kernel_w; - scalar_t v = 0.0; - - if (kx < p.kernel_w & ky < p.kernel_h) { - v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; - } - - sk[ky][kx] = v; - } - - for (int loop_major = 0, major_idx = major_idx_base; - loop_major < p.loop_major & major_idx < p.major_dim; - loop_major++, major_idx++) { - for (int loop_x = 0, tile_out_x = tile_out_x_base; - loop_x < p.loop_x & tile_out_x < p.out_w; - loop_x++, tile_out_x += tile_out_w) { - int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; - int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; - int tile_in_x = floor_div(tile_mid_x, up_x); - int tile_in_y = floor_div(tile_mid_y, up_y); - - __syncthreads(); - - for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; - in_idx += blockDim.x) { - int rel_in_y = in_idx / tile_in_w; - int rel_in_x = in_idx - rel_in_y * tile_in_w; - int in_x = rel_in_x + tile_in_x; - int in_y = rel_in_y + tile_in_y; - - scalar_t v = 0.0; - - if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { - v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * - p.minor_dim + - minor_idx]; - } - - sx[rel_in_y][rel_in_x] = v; - } - - __syncthreads(); - for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; - out_idx += blockDim.x) { - int rel_out_y = out_idx / tile_out_w; - int rel_out_x = out_idx - rel_out_y * tile_out_w; - int out_x = rel_out_x + tile_out_x; - int out_y = rel_out_y + tile_out_y; - - int mid_x = tile_mid_x + rel_out_x * down_x; - int mid_y = tile_mid_y + rel_out_y * down_y; - int in_x = floor_div(mid_x, up_x); - int in_y = floor_div(mid_y, up_y); - int rel_in_x = in_x - tile_in_x; - int rel_in_y = in_y - tile_in_y; - int kernel_x = (in_x + 1) * up_x - mid_x - 1; - int kernel_y = (in_y + 1) * up_y - mid_y - 1; - - scalar_t v = 0.0; - -#pragma unroll - for (int y = 0; y < kernel_h / up_y; y++) -#pragma unroll - for (int x = 0; x < kernel_w / up_x; x++) - v += sx[rel_in_y + y][rel_in_x + x] * - sk[kernel_y + y * up_y][kernel_x + x * up_x]; - - if (out_x < p.out_w & out_y < p.out_h) { - out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + - minor_idx] = v; - } - } - } - } -} - -torch::Tensor upfirdn2d_op(const torch::Tensor &input, - const torch::Tensor &kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, - int pad_y0, int pad_y1) { - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - - UpFirDn2DKernelParams p; - - auto x = input.contiguous(); - auto k = kernel.contiguous(); - - p.major_dim = x.size(0); - p.in_h = x.size(1); - p.in_w = x.size(2); - p.minor_dim = x.size(3); - p.kernel_h = k.size(0); - p.kernel_w = k.size(1); - p.up_x = up_x; - p.up_y = up_y; - p.down_x = down_x; - p.down_y = down_y; - p.pad_x0 = pad_x0; - p.pad_x1 = pad_x1; - p.pad_y0 = pad_y0; - p.pad_y1 = pad_y1; - - p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / - p.down_y; - p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / - p.down_x; - - auto out = - at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); - - int mode = -1; - - int tile_out_h = -1; - int tile_out_w = -1; - - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 1; - tile_out_h = 16; - tile_out_w = 64; - } - - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 3 && p.kernel_w <= 3) { - mode = 2; - tile_out_h = 16; - tile_out_w = 64; - } - - if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 3; - 
tile_out_h = 16; - tile_out_w = 64; - } - - if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 2 && p.kernel_w <= 2) { - mode = 4; - tile_out_h = 16; - tile_out_w = 64; - } - - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 5; - tile_out_h = 8; - tile_out_w = 32; - } - - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && - p.kernel_h <= 2 && p.kernel_w <= 2) { - mode = 6; - tile_out_h = 8; - tile_out_w = 32; - } - - dim3 block_size; - dim3 grid_size; - - if (tile_out_h > 0 && tile_out_w > 0) { - p.loop_major = (p.major_dim - 1) / 16384 + 1; - p.loop_x = 1; - block_size = dim3(32 * 8, 1, 1); - grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, - (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, - (p.major_dim - 1) / p.loop_major + 1); - } else { - p.loop_major = (p.major_dim - 1) / 16384 + 1; - p.loop_x = 4; - block_size = dim3(4, 32, 1); - grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, - (p.out_w - 1) / (p.loop_x * block_size.y) + 1, - (p.major_dim - 1) / p.loop_major + 1); - } - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { - switch (mode) { - case 1: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 2: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 3: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 4: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 5: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 6: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - default: - upfirdn2d_kernel_large<<>>( - out.data_ptr(), x.data_ptr(), - k.data_ptr(), p); - } - }); - - return out; -} \ No newline at end of file diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py deleted file mode 100644 index 9c46c314cf2ff24fff74d7308dd8cc50767dd870..0000000000000000000000000000000000000000 --- a/torch_utils/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -#empty \ No newline at end of file diff --git a/torch_utils/ops/bias_act.cpp b/torch_utils/ops/bias_act.cpp deleted file mode 100644 index aef47317a3ae018de6ea620060337bcf44b2d649..0000000000000000000000000000000000000000 --- a/torch_utils/ops/bias_act.cpp +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. 
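Note: the bias_act op declared in this file (and in its .cu/.py companions below) fuses bias addition, an activation, a gain factor, and an optional clamp into one kernel. For the 'lrelu' configuration used by the StyleGAN2 layers, the reference behaviour, mirroring the slow Python fallback further down in this patch, is roughly the following sketch (the function name is illustrative; the compiled path differs only in performance):

import torch
import torch.nn.functional as F

def bias_act_lrelu_reference(x, b=None, dim=1, alpha=0.2, gain=2 ** 0.5, clamp=None):
    # Add the 1D bias along `dim`, apply leaky ReLU, scale by the gain, then clamp.
    if b is not None:
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    x = F.leaky_relu(x, alpha) * gain
    return x.clamp(-clamp, clamp) if clamp is not None else x

x = torch.randn(2, 8, 4, 4)
out = bias_act_lrelu_reference(x, torch.zeros(8))
assert out.shape == x.shape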
- -#include -#include -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ - -static bool has_same_layout(torch::Tensor x, torch::Tensor y) -{ - if (x.dim() != y.dim()) - return false; - for (int64_t i = 0; i < x.dim(); i++) - { - if (x.size(i) != y.size(i)) - return false; - if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) - return false; - } - return true; -} - -//------------------------------------------------------------------------ - -static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) -{ - // Validate arguments. - TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); - TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); - TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); - TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(b.dim() == 1, "b must have rank 1"); - TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); - TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); - TORCH_CHECK(grad >= 0, "grad must be non-negative"); - - // Validate layout. - TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); - TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); - TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); - TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); - TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); - - // Create output tensor. - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - torch::Tensor y = torch::empty_like(x); - TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); - - // Initialize CUDA kernel parameters. - bias_act_kernel_params p; - p.x = x.data_ptr(); - p.b = (b.numel()) ? b.data_ptr() : NULL; - p.xref = (xref.numel()) ? xref.data_ptr() : NULL; - p.yref = (yref.numel()) ? yref.data_ptr() : NULL; - p.dy = (dy.numel()) ? dy.data_ptr() : NULL; - p.y = y.data_ptr(); - p.grad = grad; - p.act = act; - p.alpha = alpha; - p.gain = gain; - p.clamp = clamp; - p.sizeX = (int)x.numel(); - p.sizeB = (int)b.numel(); - p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; - - // Choose CUDA kernel. - void* kernel; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - kernel = choose_bias_act_kernel(p); - }); - TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); - - // Launch CUDA kernel. 
- p.loopX = 4; - int blockSize = 4 * 32; - int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; - void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("bias_act", &bias_act); -} - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.cu b/torch_utils/ops/bias_act.cu deleted file mode 100644 index f0fc48475dbceb3476e5a41954a9711d5ade07e1..0000000000000000000000000000000000000000 --- a/torch_utils/ops/bias_act.cu +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "bias_act.h" - -//------------------------------------------------------------------------ -// Helpers. - -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -//------------------------------------------------------------------------ -// CUDA kernel. - -template -__global__ void bias_act_kernel(bias_act_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - int G = p.grad; - scalar_t alpha = (scalar_t)p.alpha; - scalar_t gain = (scalar_t)p.gain; - scalar_t clamp = (scalar_t)p.clamp; - scalar_t one = (scalar_t)1; - scalar_t two = (scalar_t)2; - scalar_t expRange = (scalar_t)80; - scalar_t halfExpRange = (scalar_t)40; - scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; - scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; - - // Loop over elements. - int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; - for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) - { - // Load. - scalar_t x = (scalar_t)((const T*)p.x)[xi]; - scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; - scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; - scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; - scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; - scalar_t yy = (gain != 0) ? yref / gain : 0; - scalar_t y = 0; - - // Apply bias. - ((G == 0) ? x : xref) += b; - - // linear - if (A == 1) - { - if (G == 0) y = x; - if (G == 1) y = x; - } - - // relu - if (A == 2) - { - if (G == 0) y = (x > 0) ? x : 0; - if (G == 1) y = (yy > 0) ? x : 0; - } - - // lrelu - if (A == 3) - { - if (G == 0) y = (x > 0) ? x : x * alpha; - if (G == 1) y = (yy > 0) ? x : x * alpha; - } - - // tanh - if (A == 4) - { - if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } - if (G == 1) y = x * (one - yy * yy); - if (G == 2) y = x * (one - yy * yy) * (-two * yy); - } - - // sigmoid - if (A == 5) - { - if (G == 0) y = (x < -expRange) ? 
0 : one / (exp(-x) + one); - if (G == 1) y = x * yy * (one - yy); - if (G == 2) y = x * yy * (one - yy) * (one - two * yy); - } - - // elu - if (A == 6) - { - if (G == 0) y = (x >= 0) ? x : exp(x) - one; - if (G == 1) y = (yy >= 0) ? x : x * (yy + one); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); - } - - // selu - if (A == 7) - { - if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); - if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); - if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); - } - - // softplus - if (A == 8) - { - if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); - if (G == 1) y = x * (one - exp(-yy)); - if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } - } - - // swish - if (A == 9) - { - if (G == 0) - y = (x < -expRange) ? 0 : x / (exp(-x) + one); - else - { - scalar_t c = exp(xref); - scalar_t d = c + one; - if (G == 1) - y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); - else - y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); - yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; - } - } - - // Apply gain. - y *= gain * dy; - - // Clamp. - if (clamp >= 0) - { - if (G == 0) - y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; - else - y = (yref > -clamp & yref < clamp) ? y : 0; - } - - // Store. - ((T*)p.y)[xi] = (T)y; - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p) -{ - if (p.act == 1) return (void*)bias_act_kernel; - if (p.act == 2) return (void*)bias_act_kernel; - if (p.act == 3) return (void*)bias_act_kernel; - if (p.act == 4) return (void*)bias_act_kernel; - if (p.act == 5) return (void*)bias_act_kernel; - if (p.act == 6) return (void*)bias_act_kernel; - if (p.act == 7) return (void*)bias_act_kernel; - if (p.act == 8) return (void*)bias_act_kernel; - if (p.act == 9) return (void*)bias_act_kernel; - return NULL; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); -template void* choose_bias_act_kernel (const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.h b/torch_utils/ops/bias_act.h deleted file mode 100644 index d0246aa06c3dcd5919111fdc914136014b9044b5..0000000000000000000000000000000000000000 --- a/torch_utils/ops/bias_act.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -//------------------------------------------------------------------------ -// CUDA kernel parameters. 
- -struct bias_act_kernel_params -{ - const void* x; // [sizeX] - const void* b; // [sizeB] or NULL - const void* xref; // [sizeX] or NULL - const void* yref; // [sizeX] or NULL - const void* dy; // [sizeX] or NULL - void* y; // [sizeX] - - int grad; - int act; - float alpha; - float gain; - float clamp; - - int sizeX; - int sizeB; - int stepB; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template void* choose_bias_act_kernel(const bias_act_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py deleted file mode 100644 index 8041208be7680ddeceb1a87a9db9faae7101e7bf..0000000000000000000000000000000000000000 --- a/torch_utils/ops/bias_act.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom PyTorch ops for efficient bias and activation.""" - -import os -import warnings -import numpy as np -import torch -import dnnlib -import traceback - -from .. import custom_ops -from .. import misc - -#---------------------------------------------------------------------------- - -activation_funcs = { - 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), - 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), - 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), - 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), - 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), - 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), - 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), - 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), - 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), -} - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None -_null_tensor = torch.empty([0]) - -def _init(): - global _inited, _plugin - if not _inited: - _inited = True - sources = ['bias_act.cpp', 'bias_act.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -#---------------------------------------------------------------------------- - -def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): - r"""Fused bias and activation function. - - Adds bias `b` to activation tensor `x`, evaluates activation function `act`, - and scales the result by `gain`. Each of the steps is optional. In most cases, - the fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports first and second order gradients, - but not third order gradients. - - Args: - x: Input activation tensor. Can be of any shape. - b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type - as `x`. The shape must be known, and it must match the dimension of `x` - corresponding to `dim`. - dim: The dimension in `x` corresponding to the elements of `b`. - The value of `dim` is ignored if `b` is not specified. - act: Name of the activation function to evaluate, or `"linear"` to disable. - Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. - See `activation_funcs` for a full list. `None` is not allowed. - alpha: Shape parameter for the activation function, or `None` to use the default. - gain: Scaling factor for the output tensor, or `None` to use default. - See `activation_funcs` for the default scaling of each activation function. - If unsure, consider specifying 1. - clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable - the clamping (default). - impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). - - Returns: - Tensor of the same shape and datatype as `x`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) - return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Slow reference implementation of `bias_act()` using standard TensorFlow ops. - """ - assert isinstance(x, torch.Tensor) - assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Add bias. - if b is not None: - assert isinstance(b, torch.Tensor) and b.ndim == 1 - assert 0 <= dim < x.ndim - assert b.shape[0] == x.shape[dim] - x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) - - # Evaluate activation function. - alpha = float(alpha) - x = spec.func(x, alpha=alpha) - - # Scale by gain. - gain = float(gain) - if gain != 1: - x = x * gain - - # Clamp. - if clamp >= 0: - x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type - return x - -#---------------------------------------------------------------------------- - -_bias_act_cuda_cache = dict() - -def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Fast CUDA implementation of `bias_act()` using custom ops. - """ - # Parse arguments. 
- assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Lookup from cache. - key = (dim, act, alpha, gain, clamp) - if key in _bias_act_cuda_cache: - return _bias_act_cuda_cache[key] - - # Forward op. - class BiasActCuda(torch.autograd.Function): - @staticmethod - def forward(ctx, x, b): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format - x = x.contiguous(memory_format=ctx.memory_format) - b = b.contiguous() if b is not None else _null_tensor - y = x - if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: - y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - y if 'y' in spec.ref else _null_tensor) - return y - - @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ - dy = dy.contiguous(memory_format=ctx.memory_format) - x, b, y = ctx.saved_tensors - dx = None - db = None - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: - dx = dy - if act != 'linear' or gain != 1 or clamp >= 0: - dx = BiasActCudaGrad.apply(dy, x, b, y) - - if ctx.needs_input_grad[1]: - db = dx.sum([i for i in range(dx.ndim) if i != dim]) - - return dx, db - - # Backward op. - class BiasActCudaGrad(torch.autograd.Function): - @staticmethod - def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format - dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - dy if spec.has_2nd_grad else _null_tensor, - x, b, y) - return dx - - @staticmethod - def backward(ctx, d_dx): # pylint: disable=arguments-differ - d_dx = d_dx.contiguous(memory_format=ctx.memory_format) - dy, x, b, y = ctx.saved_tensors - d_dy = None - d_x = None - d_b = None - d_y = None - - if ctx.needs_input_grad[0]: - d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) - - if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): - d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) - - if spec.has_2nd_grad and ctx.needs_input_grad[2]: - d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) - - return d_dy, d_x, d_b, d_y - - # Add to cache. - _bias_act_cuda_cache[key] = BiasActCuda - return BiasActCuda - -#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py deleted file mode 100644 index 093036b728336d6f2f593aaea187054a8af8d523..0000000000000000000000000000000000000000 --- a/torch_utils/ops/conv2d_gradfix.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. 
Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom replacement for `torch.nn.functional.conv2d` that supports -arbitrarily high order gradients with zero performance penalty.""" - -import warnings -import contextlib -import torch - -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. -weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. - -@contextlib.contextmanager -def no_weight_gradients(): - global weight_gradients_disabled - old = weight_gradients_disabled - weight_gradients_disabled = True - yield - weight_gradients_disabled = old - -#---------------------------------------------------------------------------- - -def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) - return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - -def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): - if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) - return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(input): - assert isinstance(input, torch.Tensor) - if (not enabled) or (not torch.backends.cudnn.enabled): - return False - if input.device.type != 'cuda': - return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): - return True - warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') - return False - -def _tuple_of_ints(xs, ndim): - xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim - assert len(xs) == ndim - assert all(isinstance(x, int) for x in xs) - return xs - -#---------------------------------------------------------------------------- - -_conv2d_gradfix_cache = dict() - -def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): - # Parse arguments. - ndim = 2 - weight_shape = tuple(weight_shape) - stride = _tuple_of_ints(stride, ndim) - padding = _tuple_of_ints(padding, ndim) - output_padding = _tuple_of_ints(output_padding, ndim) - dilation = _tuple_of_ints(dilation, ndim) - - # Lookup from cache. - key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) - if key in _conv2d_gradfix_cache: - return _conv2d_gradfix_cache[key] - - # Validate arguments. 
- assert groups >= 1 - assert len(weight_shape) == ndim + 2 - assert all(stride[i] >= 1 for i in range(ndim)) - assert all(padding[i] >= 0 for i in range(ndim)) - assert all(dilation[i] >= 0 for i in range(ndim)) - if not transpose: - assert all(output_padding[i] == 0 for i in range(ndim)) - else: # transpose - assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) - - # Helpers. - common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) - def calc_output_padding(input_shape, output_shape): - if transpose: - return [0, 0] - return [ - input_shape[i + 2] - - (output_shape[i + 2] - 1) * stride[i] - - (1 - 2 * padding[i]) - - dilation[i] * (weight_shape[i + 2] - 1) - for i in range(ndim) - ] - - # Forward & backward. - class Conv2d(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - assert weight.shape == weight_shape - if not transpose: - output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) - else: # transpose - output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) - ctx.save_for_backward(input, weight) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input = None - grad_weight = None - grad_bias = None - - if ctx.needs_input_grad[0]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) - assert grad_input.shape == input.shape - - if ctx.needs_input_grad[1] and not weight_gradients_disabled: - grad_weight = Conv2dGradWeight.apply(grad_output, input) - assert grad_weight.shape == weight_shape - - if ctx.needs_input_grad[2]: - grad_bias = grad_output.sum([0, 2, 3]) - - return grad_input, grad_weight, grad_bias - - # Gradient with respect to the weights. 
- class Conv2dGradWeight(torch.autograd.Function): - @staticmethod - def forward(ctx, grad_output, input): - op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') - flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] - grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) - assert grad_weight.shape == weight_shape - ctx.save_for_backward(grad_output, input) - return grad_weight - - @staticmethod - def backward(ctx, grad2_grad_weight): - grad_output, input = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) - assert grad2_grad_output.shape == grad_output.shape - - if ctx.needs_input_grad[1]: - p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) - assert grad2_input.shape == input.shape - - return grad2_grad_output, grad2_input - - _conv2d_gradfix_cache[key] = Conv2d - return Conv2d - -#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py deleted file mode 100644 index 44a0883f731156af72ec19829ef0bfb8026682be..0000000000000000000000000000000000000000 --- a/torch_utils/ops/conv2d_resample.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""2D convolution with optional up/downsampling.""" - -import torch - -from .. import misc -from . import conv2d_gradfix -from . import upfirdn2d -from .upfirdn2d import _parse_padding -from .upfirdn2d import _get_filter_size - -#---------------------------------------------------------------------------- - -def _get_weight_shape(w): - with misc.suppress_tracer_warnings(): # this value will be treated as a constant - shape = [int(sz) for sz in w.shape] - misc.assert_shape(w, shape) - return shape - -#---------------------------------------------------------------------------- - -def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): - """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. - """ - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - - # Flip weight if requested. - if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). - w = w.flip([2, 3]) - - # Workaround performance pitfall in cuDNN 8.0.5, triggered when using - # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
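For context, the conv2d_gradfix module deleted above is a drop-in replacement for torch.nn.functional.conv2d that only takes effect when its module-level `enabled` flag is set and the PyTorch version is supported; otherwise every call falls back to the stock op. A minimal usage sketch (hypothetical, not part of the deleted sources, and assuming a CUDA build):

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # opt in; otherwise conv2d() falls back to F.conv2d

x = torch.randn(2, 3, 64, 64, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Temporarily skip gradients w.r.t. the weights, e.g. during a regularization pass.
with conv2d_gradfix.no_weight_gradients():
    (gx,) = torch.autograd.grad(y.sum(), [x], create_graph=True)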
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: - if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: - if out_channels <= 4 and groups == 1: - in_shape = x.shape - x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) - x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) - else: - x = x.to(memory_format=torch.contiguous_format) - w = w.to(memory_format=torch.contiguous_format) - x = conv2d_gradfix.conv2d(x, w, groups=groups) - return x.to(memory_format=torch.channels_last) - - # Otherwise => execute using conv2d_gradfix. - op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d - return op(x, w, stride=stride, padding=padding, groups=groups) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): - r"""2D convolution with optional up/downsampling. - - Padding is performed only once at the beginning, not between the operations. - - Args: - x: Input tensor of shape - `[batch_size, in_channels, in_height, in_width]`. - w: Weight tensor of shape - `[out_channels, in_channels//groups, kernel_height, kernel_width]`. - f: Low-pass filter for up/downsampling. Must be prepared beforehand by - calling upfirdn2d.setup_filter(). None = identity (default). - up: Integer upsampling factor (default: 1). - down: Integer downsampling factor (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - groups: Split input channels into N groups (default: 1). - flip_weight: False = convolution, True = correlation (default: True). - flip_filter: False = convolution, True = correlation (default: False). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and (x.ndim == 4) - assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) - assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) - assert isinstance(up, int) and (up >= 1) - assert isinstance(down, int) and (down >= 1) - assert isinstance(groups, int) and (groups >= 1) - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - fw, fh = _get_filter_size(f) - px0, px1, py0, py1 = _parse_padding(padding) - - # Adjust padding to account for up/downsampling. - if up > 1: - px0 += (fw + up - 1) // 2 - px1 += (fw - up) // 2 - py0 += (fh + up - 1) // 2 - py1 += (fh - up) // 2 - if down > 1: - px0 += (fw - down + 1) // 2 - px1 += (fw - down) // 2 - py0 += (fh - down + 1) // 2 - py1 += (fh - down) // 2 - - # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. - if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
- if kw == 1 and kh == 1 and (up > 1 and down == 1): - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - return x - - # Fast path: downsampling only => use strided convolution. - if down > 1 and up == 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: upsampling with optional downsampling => use transpose strided convolution. - if up > 1: - if groups == 1: - w = w.transpose(0, 1) - else: - w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) - w = w.transpose(1, 2) - w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) - px0 -= kw - 1 - px1 -= kw - up - py0 -= kh - 1 - py1 -= kh - up - pxt = max(min(-px0, -px1), 0) - pyt = max(min(-py0, -py1), 0) - x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - - # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. - if up == 1 and down == 1: - if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: - return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) - - # Fallback: Generic reference implementation. - x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - if down > 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - -#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/filtered_lrelu.cpp b/torch_utils/ops/filtered_lrelu.cpp deleted file mode 100644 index 4e253d1f3ffe84e54e667bf61a45dfe66264a73c..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu.cpp +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include -#include -#include "filtered_lrelu.h" - -//------------------------------------------------------------------------ - -static std::tuple filtered_lrelu( - torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, - int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) -{ - // Set CUDA device. - TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - // Validate arguments. 
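The padding adjustment at the top of conv2d_resample() above can be sanity-checked in isolation; the following standalone helper (added here for illustration only, not part of the deleted module) reproduces the x-axis arithmetic:

def adjust_padding_x(px0, px1, fw, up=1, down=1):
    # Mirrors the "Adjust padding to account for up/downsampling" block above.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
    return px0, px1

# Example: a 4-tap filter with 2x upsampling adds 2 pixels before and 1 pixel after.
assert adjust_padding_x(0, 0, fw=4, up=2) == (2, 1)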
- TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); - TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); - TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); - TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); - TORCH_CHECK(x.dim() == 4, "x must be rank 4"); - TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); - TORCH_CHECK(x.numel() > 0, "x is empty"); - TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); - TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); - TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); - TORCH_CHECK(fu.numel() > 0, "fu is empty"); - TORCH_CHECK(fd.numel() > 0, "fd is empty"); - TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); - TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); - - // Figure out how much shared memory is available on the device. - int maxSharedBytes = 0; - AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); - int sharedKB = maxSharedBytes >> 10; - - // Populate enough launch parameters to check if a CUDA kernel exists. - filtered_lrelu_kernel_params p; - p.up = up; - p.down = down; - p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. - p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); - filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB); - if (!test_spec.exec) - { - // No kernel found - return empty tensors and indicate missing kernel with return code of -1. - return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); - } - - // Input/output element size. - int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; - - // Input sizes. - int64_t xw = (int)x.size(3); - int64_t xh = (int)x.size(2); - int64_t fut_w = (int)fu.size(-1) - 1; - int64_t fut_h = (int)fu.size(0) - 1; - int64_t fdt_w = (int)fd.size(-1) - 1; - int64_t fdt_h = (int)fd.size(0) - 1; - - // Logical size of upsampled buffer. - int64_t cw = xw * up + (px0 + px1) - fut_w; - int64_t ch = xh * up + (py0 + py1) - fut_h; - TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); - TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); - - // Compute output size and allocate. - int64_t yw = (cw - fdt_w + (down - 1)) / down; - int64_t yh = (ch - fdt_h + (down - 1)) / down; - TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); - TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); - torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); - - // Allocate sign tensor. - torch::Tensor so; - torch::Tensor s = si; - bool readSigns = !!s.numel(); - int64_t sw_active = 0; // Active width of sign tensor. - if (writeSigns) - { - sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. - int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. - int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. 
- TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); - s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); - } - else if (readSigns) - sw_active = s.size(3) << 2; - - // Validate sign tensor if in use. - if (readSigns || writeSigns) - { - TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); - TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); - TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); - TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); - TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); - TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); - } - - // Populate rest of CUDA kernel parameters. - p.x = x.data_ptr(); - p.y = y.data_ptr(); - p.b = b.data_ptr(); - p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; - p.fu = fu.data_ptr(); - p.fd = fd.data_ptr(); - p.pad0 = make_int2(px0, py0); - p.gain = gain; - p.slope = slope; - p.clamp = clamp; - p.flip = (flip_filters) ? 1 : 0; - p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); - p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); - p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. - p.sOfs = make_int2(sx, sy); - p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. - - // x, y, b strides are in bytes. - p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); - p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); - p.bStride = sz * b.stride(0); - - // fu, fd strides are in elements. - p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); - p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); - - // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. - bool index64b = false; - if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; - if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; - if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; - if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; - if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; - if (s.numel() > INT_MAX) index64b = true; - - // Choose CUDA kernel. - filtered_lrelu_kernel_spec spec = { 0 }; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] - { - if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. - { - // Choose kernel based on index type, datatype and sign read/write modes. 
- if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); - } - }); - TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. - - // Launch CUDA kernel. - void* args[] = {&p}; - int bx = spec.numWarps * 32; - int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; - int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; - int gz = p.yShape.z * p.yShape.w; - - // Repeat multiple horizontal tiles in a CTA? - if (spec.xrep) - { - p.tilesXrep = spec.xrep; - p.tilesXdim = gx; - - gx = (gx + p.tilesXrep - 1) / p.tilesXrep; - std::swap(gx, gy); - } - else - { - p.tilesXrep = 0; - p.tilesXdim = 0; - } - - // Launch filter setup kernel. - AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); - - // Copy kernels to constant memory. - if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); - else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); - else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); - - // Set cache and shared memory configurations for main kernel. - AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); - if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? - AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); - AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); - - // Launch main kernel. - const int maxSubGz = 65535; // CUDA maximum for block z dimension. - for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big. - { - p.blockZofs = zofs; - int subGz = std::min(maxSubGz, gz - zofs); - AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); - } - - // Done. - return std::make_tuple(y, so, 0); -} - -//------------------------------------------------------------------------ - -static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) -{ - // Set CUDA device. - TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - // Validate arguments. - TORCH_CHECK(x.dim() == 4, "x must be rank 4"); - TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); - TORCH_CHECK(x.numel() > 0, "x is empty"); - TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); - - // Output signs if we don't have sign input. 
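The 64-bit indexing decision made by the filtered_lrelu() launcher above reduces to checking whether any reachable byte offset into x, y, b, or the sign tensor could leave the int32 range; a small Python restatement of that check (illustrative only, not code from the repository):

INT_MAX = 2**31 - 1

def needs_int64(sizes, strides_bytes):
    # Sum the most negative and most positive reachable offsets across dimensions,
    # as the launcher does per tensor before choosing an index type.
    lo = sum(min(n * s, 0) for n, s in zip(sizes, strides_bytes))
    hi = sum(max(n * s, 0) for n, s in zip(sizes, strides_bytes))
    return lo < -INT_MAX or hi > INT_MAX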
- torch::Tensor so; - torch::Tensor s = si; - bool readSigns = !!s.numel(); - if (writeSigns) - { - int64_t sw = x.size(3); - sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. - s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); - } - - // Validate sign tensor if in use. - if (readSigns || writeSigns) - { - TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); - TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); - TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); - TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); - TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); - TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); - } - - // Initialize CUDA kernel parameters. - filtered_lrelu_act_kernel_params p; - p.x = x.data_ptr(); - p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; - p.gain = gain; - p.slope = slope; - p.clamp = clamp; - p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); - p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); - p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. - p.sOfs = make_int2(sx, sy); - - // Choose CUDA kernel. - void* func = 0; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] - { - if (writeSigns) - func = choose_filtered_lrelu_act_kernel(); - else if (readSigns) - func = choose_filtered_lrelu_act_kernel(); - else - func = choose_filtered_lrelu_act_kernel(); - }); - TORCH_CHECK(func, "internal error - CUDA kernel not found"); - - // Launch CUDA kernel. - void* args[] = {&p}; - int bx = 128; // 4 warps per block. - - // Logical size of launch = writeSigns ? p.s : p.x - uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; - uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; - uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. - gx = (gx - 1) / bx + 1; - - // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. - const uint32_t gmax = 65535; - gy = std::min(gy, gmax); - gz = std::min(gz, gmax); - - // Launch. - AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); - return so; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. - m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. -} - -//------------------------------------------------------------------------ \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu.cu b/torch_utils/ops/filtered_lrelu.cu deleted file mode 100644 index 50bad61678bec055dade2b337f12ea395aeb382e..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu.cu +++ /dev/null @@ -1,1284 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. 
Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "filtered_lrelu.h" -#include - -//------------------------------------------------------------------------ -// Helpers. - -enum // Filter modes. -{ - MODE_SUSD = 0, // Separable upsampling, separable downsampling. - MODE_FUSD = 1, // Full upsampling, separable downsampling. - MODE_SUFD = 2, // Separable upsampling, full downsampling. - MODE_FUFD = 3, // Full upsampling, full downsampling. -}; - -template struct InternalType; -template <> struct InternalType -{ - typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t; - __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); } - __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); } - __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); } -}; -template <> struct InternalType -{ - typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; - __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } - __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } - __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } -}; -template <> struct InternalType -{ - typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; - __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } - __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } - __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } -}; - -#define MIN(A, B) ((A) < (B) ? (A) : (B)) -#define MAX(A, B) ((A) > (B) ? (A) : (B)) -#define CEIL_DIV(A, B) (((B)==1) ? (A) : \ - ((B)==2) ? ((int)((A)+1) >> 1) : \ - ((B)==4) ? ((int)((A)+3) >> 2) : \ - (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B))) - -// This works only up to blocks of size 256 x 256 and for all N that are powers of two. -template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) -{ - if ((N & (N-1)) && N <= 256) - y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256. - else - y = i/N; - - x = i - y*N; -} - -// Type cast stride before reading it. -template __device__ __forceinline__ T get_stride(const int64_t& x) -{ - return *reinterpret_cast(&x); -} - -//------------------------------------------------------------------------ -// Filters, setup kernel, copying function. - -#define MAX_FILTER_SIZE 32 - -// Combined up/down filter buffers so that transfer can be done with one copy. -__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. -__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. - -// Accessors to combined buffers to index up/down filters individually. -#define c_fu (c_fbuf) -#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) -#define g_fu (g_fbuf) -#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) - -// Set up filters into global memory buffer. 
-static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) -{ - for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) - { - int x, y; - fast_div_mod(x, y, idx); - - int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); - int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); - if (p.fuShape.y > 0) - g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; - else - g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; - - int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); - int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); - if (p.fdShape.y > 0) - g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; - else - g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; - } -} - -// Host function to copy filters written by setup kernel into constant buffer for main kernel. -template static cudaError_t copy_filters(cudaStream_t stream) -{ - void* src = 0; - cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); - if (err) return err; - return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); -} - -//------------------------------------------------------------------------ -// Coordinate spaces: -// - Relative to input tensor: inX, inY, tileInX, tileInY -// - Relative to input tile: relInX, relInY, tileInW, tileInH -// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH -// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH -// - Relative to output tensor: outX, outY, tileOutX, tileOutY -// -// Relationships between coordinate spaces: -// - inX = tileInX + relInX -// - inY = tileInY + relInY -// - relUpX = relInX * up + phaseInX -// - relUpY = relInY * up + phaseInY -// - relUpX = relOutX * down -// - relUpY = relOutY * down -// - outX = tileOutX + relOutX -// - outY = tileOutY + relOutY - -extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. - -template -static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) -{ - // Check that we don't try to support non-existing filter modes. 
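The fixed-point trick in fast_div_mod() above replaces an integer division by N with a multiply and a shift; a quick standalone check of that reciprocal (assuming, as the comment states, i < N*256 and N <= 256; not part of the deleted sources):

def fast_div(i, n):
    # 8.24 fixed-point reciprocal, rounded up, as in fast_div_mod() above.
    return (i * ((1 << 24) // n + 1)) >> 24

assert all(fast_div(i, n) == i // n
           for n in (3, 7, 96, 255, 256)
           for i in range(n * 256))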
- static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); - static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); - static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); - static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); - static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); - static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); - static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); - static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); - static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); - static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); - static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); - - // Static definitions. - typedef typename InternalType::scalar_t scalar_t; - typedef typename InternalType::vec2_t vec2_t; - typedef typename InternalType::vec4_t vec4_t; - const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. - const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. - const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. - const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. - const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. - const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. - - // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. - const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); - - // Sizes of logical buffers. - const int szIn = tileInH_up * tileInW; - const int szUpX = tileInH_up * tileUpW; - const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); - const int szDownX = tileUpH * tileOutW; - - // Sizes for shared memory arrays. - const int s_buf0_size_base = - (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : - (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : - (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : - (filterMode == MODE_FUFD) ? szIn : - -1; - const int s_buf1_size_base = - (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : - (filterMode == MODE_FUSD) ? szUpXY : - (filterMode == MODE_SUFD) ? szUpX : - (filterMode == MODE_FUFD) ? szUpXY : - -1; - - // Ensure U128 alignment. - const int s_buf0_size = (s_buf0_size_base + 3) & ~3; - const int s_buf1_size = (s_buf1_size_base + 3) & ~3; - - // Check at compile time that we don't use too much shared memory. - static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); - - // Declare shared memory arrays. - scalar_t* s_buf0; - scalar_t* s_buf1; - if (sharedKB <= 48) - { - // Allocate shared memory arrays here. - __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? 
(1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. - s_buf0 = s_buf0_st; - s_buf1 = s_buf0 + s_buf0_size; - } - else - { - // Use the dynamically allocated shared memory array. - s_buf0 = (scalar_t*)s_buf_raw; - s_buf1 = s_buf0 + s_buf0_size; - } - - // Pointers to the buffers. - scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] - scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] - scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] - scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] - if (filterMode == MODE_SUSD) - { - s_tileIn = s_buf0; - s_tileUpX = s_buf1; - s_tileUpXY = s_buf0; - s_tileDownX = s_buf1; - } - else if (filterMode == MODE_FUSD) - { - s_tileIn = s_buf0; - s_tileUpXY = s_buf1; - s_tileDownX = s_buf0; - } - else if (filterMode == MODE_SUFD) - { - s_tileIn = s_buf0; - s_tileUpX = s_buf1; - s_tileUpXY = s_buf0; - } - else if (filterMode == MODE_FUFD) - { - s_tileIn = s_buf0; - s_tileUpXY = s_buf1; - } - - // Allow large grids in z direction via per-launch offset. - int channelIdx = blockIdx.z + p.blockZofs; - int batchIdx = channelIdx / p.yShape.z; - channelIdx -= batchIdx * p.yShape.z; - - // Offset to output feature map. In bytes. - index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); - - // Sign shift amount. - uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; - - // Inner tile loop. - #pragma unroll 1 - for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) - { - // Locate output tile. - int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; - int tileOutX = tileX * tileOutW; - int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; - - // Locate input tile. - int tmpX = tileOutX * down - p.pad0.x; - int tmpY = tileOutY * down - p.pad0.y; - int tileInX = CEIL_DIV(tmpX, up); - int tileInY = CEIL_DIV(tmpY, up); - const int phaseInX = tileInX * up - tmpX; - const int phaseInY = tileInY * up - tmpY; - - // Extra sync if input and output buffers are the same and we are not on first tile. - if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) - __syncthreads(); - - // Load input tile & apply bias. Unrolled. - scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); - index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); - int idx = threadIdx.x; - const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); - #pragma unroll - for (int loop = 0; loop < loopCountIN; loop++) - { - int relInX, relInY; - fast_div_mod(relInX, relInY, idx); - int inX = tileInX + relInX; - int inY = tileInY + relInY; - scalar_t v = 0; - - if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) - v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; - - bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); - if (!skip) - s_tileIn[idx] = v; - - idx += threadsPerBlock; - } - - if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. - { - // Horizontal upsampling. 
- __syncthreads(); - if (up == 4) - { - for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) - { - int relUpX0, relInY; - fast_div_mod(relUpX0, relInY, idx); - int relInX0 = relUpX0 / up; - int src0 = relInX0 + tileInW * relInY; - int dst = relInY * tileUpW + relUpX0; - vec4_t v = InternalType::zero_vec4(); - scalar_t a = s_tileIn[src0]; - if (phaseInX == 0) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - v.y += a * (scalar_t)c_fu[step * up + 3]; - v.z += a * (scalar_t)c_fu[step * up + 2]; - v.w += a * (scalar_t)c_fu[step * up + 1]; - } - } - else if (phaseInX == 1) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 1]; - v.y += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - v.z += a * (scalar_t)c_fu[step * up + 3]; - v.w += a * (scalar_t)c_fu[step * up + 2]; - } - } - else if (phaseInX == 2) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 2]; - v.y += a * (scalar_t)c_fu[step * up + 1]; - v.z += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - v.w += a * (scalar_t)c_fu[step * up + 3]; - } - } - else // (phaseInX == 3) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 3]; - v.y += a * (scalar_t)c_fu[step * up + 2]; - v.z += a * (scalar_t)c_fu[step * up + 1]; - v.w += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - } - } - s_tileUpX[dst+0] = v.x; - s_tileUpX[dst+1] = v.y; - s_tileUpX[dst+2] = v.z; - s_tileUpX[dst+3] = v.w; - } - } - else if (up == 2) - { - bool p0 = (phaseInX == 0); - for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) - { - int relUpX0, relInY; - fast_div_mod(relUpX0, relInY, idx); - int relInX0 = relUpX0 / up; - int src0 = relInX0 + tileInW * relInY; - int dst = relInY * tileUpW + relUpX0; - vec2_t v = InternalType::zero_vec2(); - scalar_t a = s_tileIn[src0]; - if (p0) // (phaseInX == 0) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - v.y += a * (scalar_t)c_fu[step * up + 1]; - } - } - else // (phaseInX == 1) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 1]; - v.y += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileIn[src0 + step + 1]; - } - } - s_tileUpX[dst+0] = v.x; - s_tileUpX[dst+1] = v.y; - } - } - - // Vertical upsampling & nonlinearity. - - __syncthreads(); - int groupMask = 15 << ((threadIdx.x & 31) & ~3); - int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. - int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. - if (up == 4) - { - minY -= 3; // Adjust according to block height. 
- for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) - { - int relUpX, relInY0; - fast_div_mod(relUpX, relInY0, idx); - int relUpY0 = relInY0 * up; - int src0 = relInY0 * tileUpW + relUpX; - int dst = relUpY0 * tileUpW + relUpX; - vec4_t v = InternalType::zero_vec4(); - - scalar_t a = s_tileUpX[src0]; - if (phaseInY == 0) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - v.y += a * (scalar_t)c_fu[step * up + 3]; - v.z += a * (scalar_t)c_fu[step * up + 2]; - v.w += a * (scalar_t)c_fu[step * up + 1]; - } - } - else if (phaseInY == 1) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 1]; - v.y += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - v.z += a * (scalar_t)c_fu[step * up + 3]; - v.w += a * (scalar_t)c_fu[step * up + 2]; - } - } - else if (phaseInY == 2) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 2]; - v.y += a * (scalar_t)c_fu[step * up + 1]; - v.z += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - v.w += a * (scalar_t)c_fu[step * up + 3]; - } - } - else // (phaseInY == 3) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 3]; - v.y += a * (scalar_t)c_fu[step * up + 2]; - v.z += a * (scalar_t)c_fu[step * up + 1]; - v.w += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - } - } - - int x = tileOutX * down + relUpX; - int y = tileOutY * down + relUpY0; - int signX = x + p.sOfs.x; - int signY = y + p.sOfs.y; - int signZ = blockIdx.z + p.blockZofs; - int signXb = signX >> 2; - index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); - index_t si1 = si0 + p.sShape.x; - index_t si2 = si0 + p.sShape.x * 2; - index_t si3 = si0 + p.sShape.x * 3; - - v.x *= (scalar_t)((float)up * (float)up * p.gain); - v.y *= (scalar_t)((float)up * (float)up * p.gain); - v.z *= (scalar_t)((float)up * (float)up * p.gain); - v.w *= (scalar_t)((float)up * (float)up * p.gain); - - if (signWrite) - { - if (!enableWriteSkip) - { - // Determine and write signs. - int sx = __float_as_uint(v.x) >> 31 << 0; - int sy = __float_as_uint(v.y) >> 31 << 8; - int sz = __float_as_uint(v.z) >> 31 << 16; - int sw = __float_as_uint(v.w) >> 31 << 24; - if (sx) v.x *= p.slope; - if (sy) v.y *= p.slope; - if (sz) v.z *= p.slope; - if (sw) v.w *= p.slope; - if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } - if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } - if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } - if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } - - if ((uint32_t)signXb < p.swLimit && signY >= minY) - { - // Combine signs. - uint32_t s = sx + sy + sw + sz; - s <<= (signX & 3) << 1; - s |= __shfl_xor_sync(groupMask, s, 1); - s |= __shfl_xor_sync(groupMask, s, 2); - - // Write signs. 
- if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } - if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } - if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } - if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } - } - } - else - { - // Determine and write signs. - if ((uint32_t)signXb < p.swLimit && signY >= minY) - { - int sx = __float_as_uint(v.x) >> 31 << 0; - int sy = __float_as_uint(v.y) >> 31 << 8; - int sz = __float_as_uint(v.z) >> 31 << 16; - int sw = __float_as_uint(v.w) >> 31 << 24; - if (sx) v.x *= p.slope; - if (sy) v.y *= p.slope; - if (sz) v.z *= p.slope; - if (sw) v.w *= p.slope; - if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } - if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } - if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } - if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } - - // Combine signs. - uint32_t s = sx + sy + sw + sz; - s <<= (signX & 3) << 1; - s |= __shfl_xor_sync(groupMask, s, 1); - s |= __shfl_xor_sync(groupMask, s, 2); - - // Write signs. - if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } - if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } - if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } - if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } - } - else - { - // Just compute the values. - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); - if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); - } - } - } - else if (signRead) // Read signs and apply. - { - if ((uint32_t)signXb < p.swLimit) - { - int ss = (signX & 3) << 1; - if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } - if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } - if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } - if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } - } - } - else // Forward pass with no sign write. - { - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); - if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); - } - - s_tileUpXY[dst + 0 * tileUpW] = v.x; - if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; - if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; - if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; - } - } - else if (up == 2) - { - minY -= 1; // Adjust according to block height. 
- for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) - { - int relUpX, relInY0; - fast_div_mod(relUpX, relInY0, idx); - int relUpY0 = relInY0 * up; - int src0 = relInY0 * tileUpW + relUpX; - int dst = relUpY0 * tileUpW + relUpX; - vec2_t v = InternalType::zero_vec2(); - - scalar_t a = s_tileUpX[src0]; - if (phaseInY == 0) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - v.y += a * (scalar_t)c_fu[step * up + 1]; - } - } - else // (phaseInY == 1) - { - #pragma unroll - for (int step = 0; step < fuSize / up; step++) - { - v.x += a * (scalar_t)c_fu[step * up + 1]; - v.y += a * (scalar_t)c_fu[step * up + 0]; - a = s_tileUpX[src0 + (step + 1) * tileUpW]; - } - } - - int x = tileOutX * down + relUpX; - int y = tileOutY * down + relUpY0; - int signX = x + p.sOfs.x; - int signY = y + p.sOfs.y; - int signZ = blockIdx.z + p.blockZofs; - int signXb = signX >> 2; - index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); - index_t si1 = si0 + p.sShape.x; - - v.x *= (scalar_t)((float)up * (float)up * p.gain); - v.y *= (scalar_t)((float)up * (float)up * p.gain); - - if (signWrite) - { - if (!enableWriteSkip) - { - // Determine and write signs. - int sx = __float_as_uint(v.x) >> 31 << 0; - int sy = __float_as_uint(v.y) >> 31 << 8; - if (sx) v.x *= p.slope; - if (sy) v.y *= p.slope; - if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } - if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } - - if ((uint32_t)signXb < p.swLimit && signY >= minY) - { - // Combine signs. - int s = sx + sy; - s <<= signXo; - s |= __shfl_xor_sync(groupMask, s, 1); - s |= __shfl_xor_sync(groupMask, s, 2); - - // Write signs. - if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } - if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } - } - } - else - { - // Determine and write signs. - if ((uint32_t)signXb < p.swLimit && signY >= minY) - { - int sx = __float_as_uint(v.x) >> 31 << 0; - int sy = __float_as_uint(v.y) >> 31 << 8; - if (sx) v.x *= p.slope; - if (sy) v.y *= p.slope; - if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } - if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } - - // Combine signs. - int s = sx + sy; - s <<= signXo; - s |= __shfl_xor_sync(groupMask, s, 1); - s |= __shfl_xor_sync(groupMask, s, 2); - - // Write signs. - if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } - if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } - } - else - { - // Just compute the values. - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - } - } - } - else if (signRead) // Read signs and apply. - { - if ((uint32_t)signXb < p.swLimit) - { - if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } - if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } - } - } - else // Forward pass with no sign write. - { - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - } - - if (!downInline) - { - // Write into temporary buffer. 
- s_tileUpXY[dst] = v.x; - if (relUpY0 < tileUpH - 1) - s_tileUpXY[dst + tileUpW] = v.y; - } - else - { - // Write directly into output buffer. - if ((uint32_t)x < p.yShape.x) - { - int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); - index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; - if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); - if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); - } - } - } - } - } - else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) - { - // Full upsampling filter. - - if (up == 2) - { - // 2 x 2-wide. - __syncthreads(); - int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. - for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) - { - int relUpX0, relUpY0; - fast_div_mod(relUpX0, relUpY0, idx); - int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); - int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); - int src0 = relInX0 + tileInW * relInY0; - int tap0y = (relInY0 * up + phaseInY - relUpY0); - - #define X_LOOP(TAPY, PX) \ - for (int sx = 0; sx < fuSize / up; sx++) \ - { \ - v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ - v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ - v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ - v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ - } - - vec4_t v = InternalType::zero_vec4(); - if (tap0y == 0 && phaseInX == 0) - #pragma unroll - for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; - #pragma unroll - X_LOOP(0, 0) } - if (tap0y == 0 && phaseInX == 1) - #pragma unroll - for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; - #pragma unroll - X_LOOP(0, 1) } - if (tap0y == 1 && phaseInX == 0) - #pragma unroll - for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; - #pragma unroll - X_LOOP(1, 0) } - if (tap0y == 1 && phaseInX == 1) - #pragma unroll - for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; - #pragma unroll - X_LOOP(1, 1) } - - #undef X_LOOP - - int x = tileOutX * down + relUpX0; - int y = tileOutY * down + relUpY0; - int signX = x + p.sOfs.x; - int signY = y + p.sOfs.y; - int signZ = blockIdx.z + p.blockZofs; - int signXb = signX >> 2; - index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); - - v.x *= (scalar_t)((float)up * (float)up * p.gain); - v.y *= (scalar_t)((float)up * (float)up * p.gain); - v.z *= (scalar_t)((float)up * (float)up * p.gain); - v.w *= (scalar_t)((float)up * (float)up * p.gain); - - if (signWrite) - { - if (!enableWriteSkip) - { - // Determine and write signs. 
- int sx = __float_as_uint(v.x) >> 31; - int sy = __float_as_uint(v.y) >> 31; - int sz = __float_as_uint(v.z) >> 31; - int sw = __float_as_uint(v.w) >> 31; - if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } - if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } - if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } - if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } - - if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) - { - p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); - } - } - else - { - // Determine and write signs. - if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) - { - int sx = __float_as_uint(v.x) >> 31; - int sy = __float_as_uint(v.y) >> 31; - int sz = __float_as_uint(v.z) >> 31; - int sw = __float_as_uint(v.w) >> 31; - if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } - if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } - if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } - if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } - - p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); - } - else - { - // Just compute the values. - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); - if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); - } - } - } - else if (signRead) // Read sign and apply. - { - if ((uint32_t)signY < p.sShape.y) - { - int s = 0; - if ((uint32_t)signXb < p.swLimit) s = p.s[si]; - if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; - s >>= (signX & 3) << 1; - if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; - if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; - if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; - if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; - } - } - else // Forward pass with no sign write. - { - if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); - if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); - if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); - if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); - } - - s_tileUpXY[idx + 0] = v.x; - s_tileUpXY[idx + 1] = v.y; - s_tileUpXY[idx + 2] = v.z; - s_tileUpXY[idx + 3] = v.w; - } - } - else if (up == 1) - { - __syncthreads(); - uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); - int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. - for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) - { - int relUpX0, relUpY0; - fast_div_mod(relUpX0, relUpY0, idx); - scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. 
- - int x = tileOutX * down + relUpX0; - int y = tileOutY * down + relUpY0; - int signX = x + p.sOfs.x; - int signY = y + p.sOfs.y; - int signZ = blockIdx.z + p.blockZofs; - int signXb = signX >> 2; - index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); - v *= (scalar_t)((float)up * (float)up * p.gain); - - if (signWrite) - { - if (!enableWriteSkip) - { - // Determine and write sign. - uint32_t s = 0; - uint32_t signXbit = (1u << signXo); - if (v < 0.f) - { - s = signXbit; - v *= p.slope; - } - if (fabsf(v) > p.clamp) - { - s = signXbit * 2; - v = InternalType::clamp(v, p.clamp); - } - if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) - { - s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. - s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. - p.s[si] = s; // Write. - } - } - else - { - // Determine and write sign. - if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) - { - uint32_t s = 0; - uint32_t signXbit = (1u << signXo); - if (v < 0.f) - { - s = signXbit; - v *= p.slope; - } - if (fabsf(v) > p.clamp) - { - s = signXbit * 2; - v = InternalType::clamp(v, p.clamp); - } - s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. - s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. - p.s[si] = s; // Write. - } - else - { - // Just compute the value. - if (v < 0.f) v *= p.slope; - v = InternalType::clamp(v, p.clamp); - } - } - } - else if (signRead) - { - // Read sign and apply if within sign tensor bounds. - if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) - { - int s = p.s[si]; - s >>= signXo; - if (s & 1) v *= p.slope; - if (s & 2) v = 0.f; - } - } - else // Forward pass with no sign write. - { - if (v < 0.f) v *= p.slope; - v = InternalType::clamp(v, p.clamp); - } - - if (!downInline) // Write into temporary buffer. - s_tileUpXY[idx] = v; - else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer - *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); - } - } - } - - // Downsampling. - if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) - { - // Horizontal downsampling. - __syncthreads(); - if (down == 4 && tileOutW % 4 == 0) - { - // Calculate 4 pixels at a time. - for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) - { - int relOutX0, relUpY; - fast_div_mod(relOutX0, relUpY, idx); - int relUpX0 = relOutX0 * down; - int src0 = relUpY * tileUpW + relUpX0; - vec4_t v = InternalType::zero_vec4(); - #pragma unroll - for (int step = 0; step < fdSize; step++) - { - v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; - v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; - v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; - v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; - } - s_tileDownX[idx+0] = v.x; - s_tileDownX[idx+1] = v.y; - s_tileDownX[idx+2] = v.z; - s_tileDownX[idx+3] = v.w; - } - } - else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) - { - // Calculate 2 pixels at a time. 
- for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) - { - int relOutX0, relUpY; - fast_div_mod(relOutX0, relUpY, idx); - int relUpX0 = relOutX0 * down; - int src0 = relUpY * tileUpW + relUpX0; - vec2_t v = InternalType::zero_vec2(); - #pragma unroll - for (int step = 0; step < fdSize; step++) - { - v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; - v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; - } - s_tileDownX[idx+0] = v.x; - s_tileDownX[idx+1] = v.y; - } - } - else - { - // Calculate 1 pixel at a time. - for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) - { - int relOutX0, relUpY; - fast_div_mod(relOutX0, relUpY, idx); - int relUpX0 = relOutX0 * down; - int src = relUpY * tileUpW + relUpX0; - scalar_t v = 0.f; - #pragma unroll - for (int step = 0; step < fdSize; step++) - v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; - s_tileDownX[idx] = v; - } - } - - // Vertical downsampling & store output tile. - __syncthreads(); - for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) - { - int relOutX, relOutY0; - fast_div_mod(relOutX, relOutY0, idx); - int relUpY0 = relOutY0 * down; - int src0 = relUpY0 * tileOutW + relOutX; - scalar_t v = 0; - #pragma unroll - for (int step = 0; step < fdSize; step++) - v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; - - int outX = tileOutX + relOutX; - int outY = tileOutY + relOutY0; - - if (outX < p.yShape.x & outY < p.yShape.y) - *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; - } - } - else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) - { - // Full downsampling filter. - if (down == 2) - { - // 2-wide. - __syncthreads(); - for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) - { - int relOutX0, relOutY0; - fast_div_mod(relOutX0, relOutY0, idx); - int relUpX0 = relOutX0 * down; - int relUpY0 = relOutY0 * down; - int src0 = relUpY0 * tileUpW + relUpX0; - vec2_t v = InternalType::zero_vec2(); - #pragma unroll - for (int sy = 0; sy < fdSize; sy++) - #pragma unroll - for (int sx = 0; sx < fdSize; sx++) - { - v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; - v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; - } - - int outX = tileOutX + relOutX0; - int outY = tileOutY + relOutY0; - if ((uint32_t)outY < p.yShape.y) - { - index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; - if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; - if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; - } - } - } - else if (down == 1 && !downInline) - { - // Thread per pixel. - __syncthreads(); - for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) - { - int relOutX0, relOutY0; - fast_div_mod(relOutX0, relOutY0, idx); - scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. - - int outX = tileOutX + relOutX0; - int outY = tileOutY + relOutY0; - if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) - *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; - } - } - } - - if (!enableXrep) - break; - } -} - -//------------------------------------------------------------------------ -// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. 
Used for accelerating the generic variant. -// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. - -template -static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - - // Indexing. - int32_t x = threadIdx.x + blockIdx.x * blockDim.x; - int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; - int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. - - // Loop to accommodate oversized tensors. - for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) - for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) - { - // Extract z and w (channel, minibatch index). - int32_t w = q / p.xShape.z; - int32_t z = q - w * p.xShape.z; - - // Choose behavior based on sign read/write mode. - if (signWrite) - { - // Process value if in p.x. - uint32_t s = 0; - if (x < p.xShape.x && y < p.xShape.y) - { - int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; - T* pv = ((T*)p.x) + ix; - scalar_t v = (scalar_t)(*pv); - - // Gain, LReLU, clamp. - v *= p.gain; - if (v < 0.f) - { - v *= p.slope; - s = 1; // Sign. - } - if (fabsf(v) > p.clamp) - { - v = InternalType::clamp(v, p.clamp); - s = 2; // Clamp. - } - - *pv = (T)v; // Write value. - } - - // Coalesce into threads 0 and 16 of warp. - uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; - s <<= ((threadIdx.x & 15) << 1); // Shift into place. - s |= __shfl_xor_sync(m, s, 1); // Distribute. - s |= __shfl_xor_sync(m, s, 2); - s |= __shfl_xor_sync(m, s, 4); - s |= __shfl_xor_sync(m, s, 8); - - // Write signs if leader and in p.s. - if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. - { - uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. - ((uint32_t*)p.s)[is >> 4] = s; - } - } - else if (signRead) - { - // Process value if in p.x. - if (x < p.xShape.x) // y is always in. - { - int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; - T* pv = ((T*)p.x) + ix; - scalar_t v = (scalar_t)(*pv); - v *= p.gain; - - // Apply sign buffer offset. - uint32_t sx = x + p.sOfs.x; - uint32_t sy = y + p.sOfs.y; - - // Read and apply signs if we land inside valid region of sign buffer. - if (sx < p.sShape.x && sy < p.sShape.y) - { - uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. - unsigned char s = p.s[is]; - s >>= (sx & 3) << 1; // Shift into place. - if (s & 1) // Sign? - v *= p.slope; - if (s & 2) // Clamp? - v = 0.f; - } - - *pv = (T)v; // Write value. - } - } - else - { - // Forward pass with no sign write. Process value if in p.x. - if (x < p.xShape.x) // y is always in. - { - int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; - T* pv = ((T*)p.x) + ix; - scalar_t v = (scalar_t)(*pv); - v *= p.gain; - if (v < 0.f) - v *= p.slope; - if (fabsf(v) > p.clamp) - v = InternalType::clamp(v, p.clamp); - *pv = (T)v; // Write value. - } - } - } -} - -template void* choose_filtered_lrelu_act_kernel(void) -{ - return (void*)filtered_lrelu_act_kernel; -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) -{ - filtered_lrelu_kernel_spec s = { 0 }; - - // Return the first matching kernel. 
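Before the kernel-selection table below, a short note on how the sign buffer written by the activation kernel above is addressed: each element stores a 2-bit code and four codes share one byte, so element x lands in byte x >> 2 at bit offset (x & 3) * 2. A minimal Python illustration of that packing (function names are illustrative, not from the source):

def pack_codes(codes):
    # codes: iterable of 2-bit ints in {0, 1, 2}; four codes per output byte.
    out = bytearray((len(codes) + 3) // 4)
    for x, c in enumerate(codes):
        out[x >> 2] |= (c & 3) << ((x & 3) << 1)
    return bytes(out)

def read_code(packed, x):
    # Mirrors the read path above: shift by (x & 3) * 2 and mask the two bits.
    return (packed[x >> 2] >> ((x & 3) << 1)) & 3

The CUDA kernels additionally coalesce these writes across a warp with __shfl_xor_sync so that a single thread stores many codes at once.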
-#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ - if (sharedKB >= SH) \ - if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ - if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ - if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ - { \ - static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ - static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ - static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ - s.setup = (void*)setup_filters_kernel; \ - s.exec = (void*)filtered_lrelu_kernel; \ - s.tileOut = make_int2(TW, TH); \ - s.numWarps = W; \ - s.xrep = XR; \ - s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ - return s; \ - } - - // Launch parameters for various kernel specializations. - // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. - // Kernels that use more shared memory must be listed before those that use less, for the same reason. - - CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 - CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 - CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 - CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 - CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 - CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2 - CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 - CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 
6t-ups4-downf2 - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 - CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB - CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 - CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 - CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 - CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2 - CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4 - CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB - CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4 - - #undef CASE - return s; // No kernel found. -} - -//------------------------------------------------------------------------ \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu.h b/torch_utils/ops/filtered_lrelu.h deleted file mode 100644 index 524c804122a2582e20e2e4e9c49267e1a1b6db60..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include - -//------------------------------------------------------------------------ -// CUDA kernel parameters. - -struct filtered_lrelu_kernel_params -{ - // These parameters decide which kernel to use. - int up; // upsampling ratio (1, 2, 4) - int down; // downsampling ratio (1, 2, 4) - int2 fuShape; // [size, 1] | [size, size] - int2 fdShape; // [size, 1] | [size, size] - - int _dummy; // Alignment. - - // Rest of the parameters. - const void* x; // Input tensor. - void* y; // Output tensor. - const void* b; // Bias tensor. - unsigned char* s; // Sign tensor in/out. NULL if unused. - const float* fu; // Upsampling filter. - const float* fd; // Downsampling filter. - - int2 pad0; // Left/top padding. 
- float gain; // Additional gain factor. - float slope; // Leaky ReLU slope on negative side. - float clamp; // Clamp after nonlinearity. - int flip; // Filter kernel flip for gradient computation. - - int tilesXdim; // Original number of horizontal output tiles. - int tilesXrep; // Number of horizontal tiles per CTA. - int blockZofs; // Block z offset to support large minibatch, channel dimensions. - - int4 xShape; // [width, height, channel, batch] - int4 yShape; // [width, height, channel, batch] - int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. - int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. - int swLimit; // Active width of sign tensor in bytes. - - longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. - longlong4 yStride; // - int64_t bStride; // - longlong3 fuStride; // - longlong3 fdStride; // -}; - -struct filtered_lrelu_act_kernel_params -{ - void* x; // Input/output, modified in-place. - unsigned char* s; // Sign tensor in/out. NULL if unused. - - float gain; // Additional gain factor. - float slope; // Leaky ReLU slope on negative side. - float clamp; // Clamp after nonlinearity. - - int4 xShape; // [width, height, channel, batch] - longlong4 xStride; // Input/output tensor strides, same order as in shape. - int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. - int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. -}; - -//------------------------------------------------------------------------ -// CUDA kernel specialization. - -struct filtered_lrelu_kernel_spec -{ - void* setup; // Function for filter kernel setup. - void* exec; // Function for main operation. - int2 tileOut; // Width/height of launch tile. - int numWarps; // Number of warps per thread block, determines launch block size. - int xrep; // For processing multiple horizontal tiles per thread block. - int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template void* choose_filtered_lrelu_act_kernel(void); -template cudaError_t copy_filters(cudaStream_t stream); - -//------------------------------------------------------------------------ \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu.py b/torch_utils/ops/filtered_lrelu.py deleted file mode 100644 index f5e3748fb725884b18b7e8119f569722b5bbe67f..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import os -import numpy as np -import torch -import warnings - -from .. import custom_ops -from .. import misc -from . import upfirdn2d -from . 
import bias_act - -#---------------------------------------------------------------------------- - -_plugin = None - -def _init(): - global _plugin - if _plugin is None: - - # sources=['filtered_lrelu.h', 'filtered_lrelu.cu', 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'] - # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - # try: - # _plugin = custom_ops.get_plugin('filtered_lrelu_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler']) - # except: - # warnings.warn('Failed to build CUDA kernels for filtered_lrelu_plugin. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) - - _plugin = custom_ops.get_plugin_v3( - module_name='filtered_lrelu_plugin', - sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], - headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], - source_dir=os.path.dirname(__file__), - extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], - ) - return True - -def _get_filter_size(f): - if f is None: - return 1, 1 - assert isinstance(f, torch.Tensor) - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - -def _parse_padding(padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - -#---------------------------------------------------------------------------- - -def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): - r"""Filtered leaky ReLU for a batch of 2D images. - - Performs the following sequence of operations for each channel: - - 1. Add channel-specific bias if provided (`b`). - - 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). - - 3. Pad the image with the specified number of zeros on each side (`padding`). - Negative padding corresponds to cropping the image. - - 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it - so that the footprint of all output pixels lies within the input image. - - 5. Multiply each value by the provided gain factor (`gain`). - - 6. Apply leaky ReLU activation function to each value. - - 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. - - 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking - it so that the footprint of all output pixels lies within the input image. - - 9. Downsample the image by keeping every Nth pixel (`down`). - - The fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports gradients of arbitrary order. - - Args: - x: Float32/float16/float64 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - fu: Float32 upsampling FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - fd: Float32 downsampling FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type - as `x`. 
The length of vector must must match the channel dimension of `x`. - up: Integer upsampling factor (default: 1). - down: Integer downsampling factor. (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - gain: Overall scaling factor for signal magnitude (default: sqrt(2)). - slope: Slope on the negative side of leaky ReLU (default: 0.2). - clamp: Maximum magnitude for leaky ReLU output (default: None). - flip_filter: False = convolution, True = correlation (default: False). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) - return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using - existing `upfirdn2n()` and `bias_act()` ops. - """ - assert isinstance(x, torch.Tensor) and x.ndim == 4 - fu_w, fu_h = _get_filter_size(fu) - fd_w, fd_h = _get_filter_size(fd) - if b is not None: - assert isinstance(b, torch.Tensor) and b.dtype == x.dtype - misc.assert_shape(b, [x.shape[1]]) - assert isinstance(up, int) and up >= 1 - assert isinstance(down, int) and down >= 1 - px0, px1, py0, py1 = _parse_padding(padding) - assert gain == float(gain) and gain > 0 - assert slope == float(slope) and slope >= 0 - assert clamp is None or (clamp == float(clamp) and clamp >= 0) - - # Calculate output size. - batch_size, channels, in_h, in_w = x.shape - in_dtype = x.dtype - out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - # Compute using existing ops. - x = bias_act.bias_act(x=x, b=b) # Apply bias. - x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. - x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp. - x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample. - - # Check output shape & dtype. - misc.assert_shape(x, [batch_size, channels, out_h, out_w]) - assert x.dtype == in_dtype - return x - -#---------------------------------------------------------------------------- - -_filtered_lrelu_cuda_cache = dict() - -def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - """Fast CUDA implementation of `filtered_lrelu()` using custom ops. 
- """ - assert isinstance(up, int) and up >= 1 - assert isinstance(down, int) and down >= 1 - px0, px1, py0, py1 = _parse_padding(padding) - assert gain == float(gain) and gain > 0 - gain = float(gain) - assert slope == float(slope) and slope >= 0 - slope = float(slope) - assert clamp is None or (clamp == float(clamp) and clamp >= 0) - clamp = float(clamp if clamp is not None else 'inf') - - # Lookup from cache. - key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) - if key in _filtered_lrelu_cuda_cache: - return _filtered_lrelu_cuda_cache[key] - - # Forward op. - class FilteredLReluCuda(torch.autograd.Function): - @staticmethod - def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ - assert isinstance(x, torch.Tensor) and x.ndim == 4 - - # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). - if fu is None: - fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) - if fd is None: - fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert 1 <= fu.ndim <= 2 - assert 1 <= fd.ndim <= 2 - - # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. - if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: - fu = fu.square()[None] - if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: - fd = fd.square()[None] - - # Missing sign input tensor. - if si is None: - si = torch.empty([0]) - - # Missing bias tensor. - if b is None: - b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) - - # Construct internal sign tensor only if gradients are needed. - write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) - - # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. - strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] - if any(a < b for a, b in zip(strides[:-1], strides[1:])): - warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) - - # Call C++/Cuda plugin if datatype is supported. - if x.dtype in [torch.float16, torch.float32]: - if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): - warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) - y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) - else: - return_code = -1 - - # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because - # only the bit-packed sign tensor is retained for gradient computation. - if return_code < 0: - warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) - - y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. - y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. - so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. - y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. - - # Prepare for gradient computation. 
- ctx.save_for_backward(fu, fd, (si if si.numel() else so)) - ctx.x_shape = x.shape - ctx.y_shape = y.shape - ctx.s_ofs = sx, sy - return y - - @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ - fu, fd, si = ctx.saved_tensors - _, _, xh, xw = ctx.x_shape - _, _, yh, yw = ctx.y_shape - sx, sy = ctx.s_ofs - dx = None # 0 - dfu = None; assert not ctx.needs_input_grad[1] - dfd = None; assert not ctx.needs_input_grad[2] - db = None # 3 - dsi = None; assert not ctx.needs_input_grad[4] - dsx = None; assert not ctx.needs_input_grad[5] - dsy = None; assert not ctx.needs_input_grad[6] - - if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: - pp = [ - (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, - xw * up - yw * down + px0 - (up - 1), - (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, - xh * up - yh * down + py0 - (up - 1), - ] - gg = gain * (up ** 2) / (down ** 2) - ff = (not flip_filter) - sx = sx - (fu.shape[-1] - 1) + px0 - sy = sy - (fu.shape[0] - 1) + py0 - dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) - - if ctx.needs_input_grad[3]: - db = dx.sum([0, 2, 3]) - - return dx, dfu, dfd, db, dsi, dsx, dsy - - # Add to cache. - _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda - return FilteredLReluCuda - -#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu_ns.cu b/torch_utils/ops/filtered_lrelu_ns.cu deleted file mode 100644 index a65e743bc4a4c760a6f0605249041fdd52ec264d..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu_ns.cu +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include "filtered_lrelu.cu" - -// Template/kernel specializations for no signs mode (no gradients required). - -// Full op, 32-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Full op, 64-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Activation/signs only for generic variant. 64-bit indexing. -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); - -// Copy filters to constant memory. -template cudaError_t copy_filters(cudaStream_t stream); \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu_rd.cu b/torch_utils/ops/filtered_lrelu_rd.cu deleted file mode 100644 index 79c75b0c22b6b09476782166e30e74d00f2c7d61..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu_rd.cu +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
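The backward pass above reuses the same fused op with the roles of up/down and of the two filters swapped, the filters flipped, and adjusted padding and gain, while replaying the saved sign tensor. A minimal Python sketch of that parameter arithmetic, assuming names matching backward() (fu_w/fu_h and fd_w/fd_h are filter widths/heights, px0/py0 the forward left/top padding, xw/xh and yw/yh the forward input/output sizes):

def grad_params(fu_w, fu_h, fd_w, fd_h, px0, py0, xw, xh, yw, yh, up, down, gain, flip_filter):
    # Padding for the gradient op (mirrors pp above).
    pp = [
        (fu_w - 1) + (fd_w - 1) - px0,
        xw * up - yw * down + px0 - (up - 1),
        (fu_h - 1) + (fd_h - 1) - py0,
        xh * up - yh * down + py0 - (up - 1),
    ]
    gg = gain * (up ** 2) / (down ** 2)  # rescaled gain (mirrors gg above)
    ff = not flip_filter                 # flipped filters (mirrors ff above)
    return pp, gg, ff

dx is then produced by the same autograd class with up and down exchanged and fd/fu passed in reverse order, and db is simply dx summed over the batch and spatial dimensions.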
-// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include "filtered_lrelu.cu" - -// Template/kernel specializations for sign read mode. - -// Full op, 32-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Full op, 64-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Activation/signs only for generic variant. 64-bit indexing. -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); - -// Copy filters to constant memory. -template cudaError_t copy_filters(cudaStream_t stream); \ No newline at end of file diff --git a/torch_utils/ops/filtered_lrelu_wr.cu b/torch_utils/ops/filtered_lrelu_wr.cu deleted file mode 100644 index 7c3e82eaf5ba7df13adade0e77a7796f57bfdf10..0000000000000000000000000000000000000000 --- a/torch_utils/ops/filtered_lrelu_wr.cu +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include "filtered_lrelu.cu" - -// Template/kernel specializations for sign write mode. - -// Full op, 32-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Full op, 64-bit indexing. -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); -template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); - -// Activation/signs only for generic variant. 64-bit indexing. -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); -template void* choose_filtered_lrelu_act_kernel(void); - -// Copy filters to constant memory. -template cudaError_t copy_filters(cudaStream_t stream); \ No newline at end of file diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py deleted file mode 100644 index 06530ed5e0731b1355b18c7fe1526786dc683d26..0000000000000000000000000000000000000000 --- a/torch_utils/ops/fma.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
-# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" - -import torch - -#---------------------------------------------------------------------------- - -def fma(a, b, c): # => a * b + c - return _FusedMultiplyAdd.apply(a, b, c) - -#---------------------------------------------------------------------------- - -class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c - @staticmethod - def forward(ctx, a, b, c): # pylint: disable=arguments-differ - out = torch.addcmul(c, a, b) - ctx.save_for_backward(a, b) - ctx.c_shape = c.shape - return out - - @staticmethod - def backward(ctx, dout): # pylint: disable=arguments-differ - a, b = ctx.saved_tensors - c_shape = ctx.c_shape - da = None - db = None - dc = None - - if ctx.needs_input_grad[0]: - da = _unbroadcast(dout * b, a.shape) - - if ctx.needs_input_grad[1]: - db = _unbroadcast(dout * a, b.shape) - - if ctx.needs_input_grad[2]: - dc = _unbroadcast(dout, c_shape) - - return da, db, dc - -#---------------------------------------------------------------------------- - -def _unbroadcast(x, shape): - extra_dims = x.ndim - len(shape) - assert extra_dims >= 0 - dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] - if len(dim): - x = x.sum(dim=dim, keepdim=True) - if extra_dims: - x = x.reshape(-1, *x.shape[extra_dims+1:]) - assert x.shape == shape - return x - -#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py deleted file mode 100644 index 4f69aad7510d49d55cd865b5e2554703f979b185..0000000000000000000000000000000000000000 --- a/torch_utils/ops/grid_sample_gradfix.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom replacement for `torch.nn.functional.grid_sample` that -supports arbitrarily high order gradients between the input and output. -Only works on 2D images and assumes -`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" - -import warnings -import torch - -# pylint: disable=redefined-builtin -# pylint: disable=arguments-differ -# pylint: disable=protected-access - -#---------------------------------------------------------------------------- - -enabled = False # Enable the custom op by setting this to true. 
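A minimal PyTorch example of what _unbroadcast above accomplishes: a gradient flowing into a broadcast operand is summed over the broadcast dimensions so that it matches the operand's original shape (the tensor shapes here are illustrative only):

import torch

a = torch.randn(4, 1, 8)                 # operand that was broadcast along dim 1
dout = torch.randn(4, 6, 8)              # upstream gradient in the broadcast shape
da = dout.sum(dim=1, keepdim=True)       # reduce the broadcast dim back to (4, 1, 8)
assert da.shape == a.shape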
- -#---------------------------------------------------------------------------- - -def grid_sample(input, grid): - if _should_use_custom_op(): - return _GridSample2dForward.apply(input, grid) - return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - -#---------------------------------------------------------------------------- - -def _should_use_custom_op(): - if not enabled: - return False - if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): - return True - warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') - return False - -#---------------------------------------------------------------------------- - -class _GridSample2dForward(torch.autograd.Function): - @staticmethod - def forward(ctx, input, grid): - assert input.ndim == 4 - assert grid.ndim == 4 - output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) - ctx.save_for_backward(input, grid) - return output - - @staticmethod - def backward(ctx, grad_output): - input, grid = ctx.saved_tensors - grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) - return grad_input, grad_grid - -#---------------------------------------------------------------------------- - -class _GridSample2dBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, grad_output, input, grid): - op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') - grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) - ctx.save_for_backward(grid) - return grad_input, grad_grid - - @staticmethod - def backward(ctx, grad2_grad_input, grad2_grad_grid): - _ = grad2_grad_grid # unused - grid, = ctx.saved_tensors - grad2_grad_output = None - grad2_input = None - grad2_grid = None - - if ctx.needs_input_grad[0]: - grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) - - assert not ctx.needs_input_grad[2] - return grad2_grad_output, grad2_input, grad2_grid - -#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/upfirdn2d.cpp b/torch_utils/ops/upfirdn2d.cpp deleted file mode 100644 index 42bdd483490a555266c8f9b9dd6684464b2088bc..0000000000000000000000000000000000000000 --- a/torch_utils/ops/upfirdn2d.cpp +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ - -static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) -{ - // Validate arguments. 
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); - TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); - TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); - TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); - TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); - TORCH_CHECK(x.dim() == 4, "x must be rank 4"); - TORCH_CHECK(f.dim() == 2, "f must be rank 2"); - TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); - TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); - TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); - - // Create output tensor. - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; - int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; - TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); - torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); - TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); - - // Initialize CUDA kernel parameters. - upfirdn2d_kernel_params p; - p.x = x.data_ptr(); - p.f = f.data_ptr(); - p.y = y.data_ptr(); - p.up = make_int2(upx, upy); - p.down = make_int2(downx, downy); - p.pad0 = make_int2(padx0, pady0); - p.flip = (flip) ? 1 : 0; - p.gain = gain; - p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); - p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); - p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); - p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); - p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); - p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); - p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; - p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; - - // Choose CUDA kernel. - upfirdn2d_kernel_spec spec; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] - { - spec = choose_upfirdn2d_kernel(p); - }); - - // Set looping options. - p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; - p.loopMinor = spec.loopMinor; - p.loopX = spec.loopX; - p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; - p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; - - // Compute grid size. - dim3 blockSize, gridSize; - if (spec.tileOutW < 0) // large - { - blockSize = dim3(4, 32, 1); - gridSize = dim3( - ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, - (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, - p.launchMajor); - } - else // small - { - blockSize = dim3(256, 1, 1); - gridSize = dim3( - ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, - (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, - p.launchMajor); - } - - // Launch CUDA kernel. 
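Before the kernel launch below, a quick check of the output-size rule used above: per axis, upfirdn2d upsamples by up, pads, convolves with the filter, and downsamples by down, giving out = (in * up + pad0 + pad1 - f + down) // down. A minimal Python sketch with illustrative numbers:

def upfirdn2d_out_size(in_size, up, pad0, pad1, f_size, down):
    # Mirrors the outW / outH computation above.
    return (in_size * up + pad0 + pad1 - f_size + down) // down

# e.g. a 2x upsample with a 12-tap filter and padding (5, 6) maps 64 -> 128:
assert upfirdn2d_out_size(64, up=2, pad0=5, pad1=6, f_size=12, down=1) == 128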
- void* args[] = {&p}; - AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); - return y; -} - -//------------------------------------------------------------------------ - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("upfirdn2d", &upfirdn2d); -} - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.cu b/torch_utils/ops/upfirdn2d.cu deleted file mode 100644 index 2126f450047bb4a7f2e77b27d105207d02acffcd..0000000000000000000000000000000000000000 --- a/torch_utils/ops/upfirdn2d.cu +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. - -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include -#include "upfirdn2d.h" - -//------------------------------------------------------------------------ -// Helpers. - -template struct InternalType; -template <> struct InternalType { typedef double scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; -template <> struct InternalType { typedef float scalar_t; }; - -static __device__ __forceinline__ int floor_div(int a, int b) -{ - int t = 1 - a / b; - return (a + t * b) / b - t; -} - -//------------------------------------------------------------------------ -// Generic CUDA implementation for large filters. - -template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - - // Calculate thread index. - int minorBase = blockIdx.x * blockDim.x + threadIdx.x; - int outY = minorBase / p.launchMinor; - minorBase -= outY * p.launchMinor; - int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; - int majorBase = blockIdx.z * p.loopMajor; - if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Setup Y receptive field. - int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; - int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); - int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; - int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; - if (p.flip) - filterY = p.filterSize.y - 1 - filterY; - - // Loop over major, minor, and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) - { - int nc = major * p.sizeMinor + minor; - int n = nc / p.inSize.z; - int c = nc - n * p.inSize.z; - for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) - { - // Setup X receptive field. - int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; - int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); - int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; - int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; - if (p.flip) - filterX = p.filterSize.x - 1 - filterX; - - // Initialize pointers. 
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; - int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; - int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; - - // Inner loop. - scalar_t v = 0; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - v += (scalar_t)(*xp) * (scalar_t)(*fp); - xp += p.inStride.x; - fp += filterStepX; - } - xp += p.inStride.y - w * p.inStride.x; - fp += filterStepY - w * filterStepX; - } - - // Store result. - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } -} - -//------------------------------------------------------------------------ -// Specialized CUDA implementation for small filters. - -template -static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) -{ - typedef typename InternalType::scalar_t scalar_t; - const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; - const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; - __shared__ volatile scalar_t sf[filterH][filterW]; - __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; - - // Calculate tile index. - int minorBase = blockIdx.x; - int tileOutY = minorBase / p.launchMinor; - minorBase -= tileOutY * p.launchMinor; - minorBase *= loopMinor; - tileOutY *= tileOutH; - int tileOutXBase = blockIdx.y * p.loopX * tileOutW; - int majorBase = blockIdx.z * p.loopMajor; - if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) - return; - - // Load filter (flipped). - for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) - { - int fy = tapIdx / filterW; - int fx = tapIdx - fy * filterW; - scalar_t v = 0; - if (fx < p.filterSize.x & fy < p.filterSize.y) - { - int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; - int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; - v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; - } - sf[fy][fx] = v; - } - - // Loop over major and X. - for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) - { - int baseNC = major * p.sizeMinor + minorBase; - int n = baseNC / p.inSize.z; - int baseC = baseNC - n * p.inSize.z; - for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) - { - // Load input pixels. - int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; - int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; - int tileInX = floor_div(tileMidX, upx); - int tileInY = floor_div(tileMidY, upy); - __syncthreads(); - for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) - { - int relC = inIdx; - int relInX = relC / loopMinor; - int relInY = relInX / tileInW; - relC -= relInX * loopMinor; - relInX -= relInY * tileInW; - int c = baseC + relC; - int inX = tileInX + relInX; - int inY = tileInY + relInY; - scalar_t v = 0; - if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) - v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; - sx[relInY][relInX][relC] = v; - } - - // Loop over output pixels. 
- __syncthreads(); - for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) - { - int relC = outIdx; - int relOutX = relC / loopMinor; - int relOutY = relOutX / tileOutW; - relC -= relOutX * loopMinor; - relOutX -= relOutY * tileOutW; - int c = baseC + relC; - int outX = tileOutX + relOutX; - int outY = tileOutY + relOutY; - - // Setup receptive field. - int midX = tileMidX + relOutX * downx; - int midY = tileMidY + relOutY * downy; - int inX = floor_div(midX, upx); - int inY = floor_div(midY, upy); - int relInX = inX - tileInX; - int relInY = inY - tileInY; - int filterX = (inX + 1) * upx - midX - 1; // flipped - int filterY = (inY + 1) * upy - midY - 1; // flipped - - // Inner loop. - if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) - { - scalar_t v = 0; - #pragma unroll - for (int y = 0; y < filterH / upy; y++) - #pragma unroll - for (int x = 0; x < filterW / upx; x++) - v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; - v *= p.gain; - ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; - } - } - } - } -} - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) -{ - int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; - - upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous - if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last - - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if 
(fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; - } - if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; - } - if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 
2) // contiguous - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last - { - if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last - { - if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; - } - if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; - } - if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last - { - if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; - } - return spec; -} - -//------------------------------------------------------------------------ -// Template specializations. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.h b/torch_utils/ops/upfirdn2d.h deleted file mode 100644 index dc6e713694d3fcca0e06cecfb9437ffb4932ffe6..0000000000000000000000000000000000000000 --- a/torch_utils/ops/upfirdn2d.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) SenseTime Research. All rights reserved. 
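// (Editor's annotation, not part of the deleted file.) The selection logic above
// maps each case to a launch configuration: the memory layout (channels_last when
// p.inStride.z == 1, contiguous otherwise), the up/down factors, and the filter
// extent each select a tile size for the specialized small kernel, with the
// generic upfirdn2d_kernel_large used as the fallback when no specialization applies.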
- -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -// -// NVIDIA CORPORATION and its licensors retain all intellectual property -// and proprietary rights in and to this software, related documentation -// and any modifications thereto. Any use, reproduction, disclosure or -// distribution of this software and related documentation without an express -// license agreement from NVIDIA CORPORATION is strictly prohibited. - -#include - -//------------------------------------------------------------------------ -// CUDA kernel parameters. - -struct upfirdn2d_kernel_params -{ - const void* x; - const float* f; - void* y; - - int2 up; - int2 down; - int2 pad0; - int flip; - float gain; - - int4 inSize; // [width, height, channel, batch] - int4 inStride; - int2 filterSize; // [width, height] - int2 filterStride; - int4 outSize; // [width, height, channel, batch] - int4 outStride; - int sizeMinor; - int sizeMajor; - - int loopMinor; - int loopMajor; - int loopX; - int launchMinor; - int launchMajor; -}; - -//------------------------------------------------------------------------ -// CUDA kernel specialization. - -struct upfirdn2d_kernel_spec -{ - void* kernel; - int tileOutW; - int tileOutH; - int loopMinor; - int loopX; -}; - -//------------------------------------------------------------------------ -// CUDA kernel selection. - -template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); - -//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py deleted file mode 100644 index a14c15fe8737cf047338d2de795d7e40a1f4e9cc..0000000000000000000000000000000000000000 --- a/torch_utils/ops/upfirdn2d.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Custom PyTorch ops for efficient resampling of 2D images.""" - -import os -import warnings -import numpy as np -import torch -import traceback - -from .. import custom_ops -from .. import misc -from . import conv2d_gradfix - -#---------------------------------------------------------------------------- - -_inited = False -_plugin = None - -def _init(): - global _inited, _plugin - if not _inited: - sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] - sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] - try: - _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) - except: - warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) - return _plugin is not None - -def _parse_scaling(scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - assert isinstance(scaling, (list, tuple)) - assert all(isinstance(x, int) for x in scaling) - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - -def _parse_padding(padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, int) for x in padding) - if len(padding) == 2: - padx, pady = padding - padding = [padx, padx, pady, pady] - padx0, padx1, pady0, pady1 = padding - return padx0, padx1, pady0, pady1 - -def _get_filter_size(f): - if f is None: - return 1, 1 - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - fw = f.shape[-1] - fh = f.shape[0] - with misc.suppress_tracer_warnings(): - fw = int(fw) - fh = int(fh) - misc.assert_shape(f, [fh, fw][:f.ndim]) - assert fw >= 1 and fh >= 1 - return fw, fh - -#---------------------------------------------------------------------------- - -def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): - r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. - - Args: - f: Torch tensor, numpy array, or python list of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), - `[]` (impulse), or - `None` (identity). - device: Result device (default: cpu). - normalize: Normalize the filter so that it retains the magnitude - for constant input signal (DC)? (default: True). - flip_filter: Flip the filter? (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - separable: Return a separable filter? (default: select automatically). - - Returns: - Float32 tensor of the shape - `[filter_height, filter_width]` (non-separable) or - `[filter_taps]` (separable). - """ - # Validate. - if f is None: - f = 1 - f = torch.as_tensor(f, dtype=torch.float32) - assert f.ndim in [0, 1, 2] - assert f.numel() > 0 - if f.ndim == 0: - f = f[np.newaxis] - - # Separable? - if separable is None: - separable = (f.ndim == 1 and f.numel() >= 8) - if f.ndim == 1 and not separable: - f = f.ger(f) - assert f.ndim == (1 if separable else 2) - - # Apply normalize, flip, gain, and device. - if normalize: - f /= f.sum() - if flip_filter: - f = f.flip(list(range(f.ndim))) - f = f * (gain ** (f.ndim / 2)) - f = f.to(device=device) - return f - -#---------------------------------------------------------------------------- - -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Pad, upsample, filter, and downsample a batch of 2D images. - - Performs the following sequence of operations for each channel: - - 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). - - 2. Pad the image with the specified number of zeros on each side (`padding`). - Negative padding corresponds to cropping the image. - - 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it - so that the footprint of all output pixels lies within the input image. - - 4. Downsample the image by keeping every Nth pixel (`down`). - - This sequence of operations bears close resemblance to scipy.signal.upfirdn(). - The fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports gradients of arbitrary order. 
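The docstring above fully determines the output geometry. As a quick illustration (an editor-added sketch, not part of the deleted file; it assumes the module is importable as torch_utils.ops.upfirdn2d together with its torch_utils.misc and conv2d_gradfix dependencies), the reference path can be exercised on the CPU and the result checked against the usual upfirdn shape relation:

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 16, 16)              # [batch, channels, height, width]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # 4-tap FIR; setup_filter outer-products and normalizes it

    # impl='ref' forces the pure-PyTorch path, so no CUDA plugin build is needed.
    y = upfirdn2d.upfirdn2d(x, f, up=2, padding=1, impl='ref')

    # Per axis: out = (in * up + pad_before + pad_after - taps + down) // down
    #               = (16 * 2 + 1 + 1 - 4 + 1) // 1 = 31
    print(y.shape)                             # torch.Size([1, 3, 31, 31])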
- - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - assert f.dtype == torch.float32 and not f.requires_grad - batch_size, num_channels, in_height, in_width = x.shape - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Upsample by inserting zeros. - x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) - x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) - x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = f.to(x.dtype) - if not flip_filter: - f = f.flip(list(range(f.ndim))) - - # Convolve with the filter. - f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) - if f.ndim == 4: - x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) - else: - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) - x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) - - # Downsample by throwing away pixels. - x = x[:, :, ::downy, ::downx] - return x - -#---------------------------------------------------------------------------- - -_upfirdn2d_cuda_cache = dict() - -def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): - """Fast CUDA implementation of `upfirdn2d()` using custom ops. - """ - # Parse arguments. - upx, upy = _parse_scaling(up) - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - - # Lookup from cache. 
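# (Editor's note.) _upfirdn2d_cuda() builds a torch.autograd.Function subclass that
# is specialized to one (up, down, padding, flip_filter, gain) setting and memoizes
# it in _upfirdn2d_cuda_cache, keyed by that tuple, so repeated calls with the same
# settings reuse a single Upfirdn2dCuda class instead of re-creating it (and its
# backward wiring) on every forward pass.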
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - if key in _upfirdn2d_cuda_cache: - return _upfirdn2d_cuda_cache[key] - - # Forward op. - class Upfirdn2dCuda(torch.autograd.Function): - @staticmethod - def forward(ctx, x, f): # pylint: disable=arguments-differ - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - y = x - if f.ndim == 2: - y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) - else: - y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) - y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) - ctx.save_for_backward(f) - ctx.x_shape = x.shape - return y - - @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ - f, = ctx.saved_tensors - _, _, ih, iw = ctx.x_shape - _, _, oh, ow = dy.shape - fw, fh = _get_filter_size(f) - p = [ - fw - padx0 - 1, - iw * upx - ow * downx + padx0 - upx + 1, - fh - pady0 - 1, - ih * upy - oh * downy + pady0 - upy + 1, - ] - dx = None - df = None - - if ctx.needs_input_grad[0]: - dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) - - assert not ctx.needs_input_grad[1] - return dx, df - - # Add to cache. - _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda - return Upfirdn2dCuda - -#---------------------------------------------------------------------------- - -def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Filter a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape matches the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + fw // 2, - padx1 + (fw - 1) // 2, - pady0 + fh // 2, - pady1 + (fh - 1) // 2, - ] - return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#---------------------------------------------------------------------------- - -def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Upsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a multiple of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. 
- f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - upx, upy = _parse_scaling(up) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw + upx - 1) // 2, - padx1 + (fw - upx) // 2, - pady0 + (fh + upy - 1) // 2, - pady1 + (fh - upy) // 2, - ] - return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) - -#---------------------------------------------------------------------------- - -def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): - r"""Downsample a batch of 2D images using the given 2D FIR filter. - - By default, the result is padded so that its shape is a fraction of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the input. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - downx, downy = _parse_scaling(down) - padx0, padx1, pady0, pady1 = _parse_padding(padding) - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw - downx + 1) // 2, - padx1 + (fw - downx) // 2, - pady0 + (fh - downy + 1) // 2, - pady1 + (fh - downy) // 2, - ] - return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) - -#---------------------------------------------------------------------------- diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py deleted file mode 100644 index 50269409c8d9f7c38d7870ee7c8e4660bfb4115c..0000000000000000000000000000000000000000 --- a/torch_utils/persistence.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
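Before the persistence module below, a brief illustration of the three convenience wrappers deleted above (filter2d, upsample2d, downsample2d). This is an editor-added sketch, not part of the diff; it assumes the same torch_utils.ops.upfirdn2d import path and uses the reference implementation so it runs without the CUDA plugin:

    import torch
    from torch_utils.ops import upfirdn2d

    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # shared antialiasing filter
    x = torch.randn(2, 3, 64, 64)

    blur = upfirdn2d.filter2d(x, f, impl='ref')              # [2, 3, 64, 64]   (shape preserved)
    up   = upfirdn2d.upsample2d(x, f, up=2, impl='ref')      # [2, 3, 128, 128]
    down = upfirdn2d.downsample2d(x, f, down=2, impl='ref')  # [2, 3, 32, 32]

    # Each wrapper only computes the padding needed to center the filter and then
    # defers to upfirdn2d(); upsample2d additionally scales gain by up_x * up_y so
    # that signal magnitude is preserved after zero-insertion.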
- -"""Facilities for pickling Python code alongside other data. - -The pickled code is automatically imported into a separate Python module -during unpickling. This way, any previously exported pickles will remain -usable even if the original code is no longer available, or if the current -version of the code is not consistent with what was originally pickled.""" - -import sys -import pickle -import io -import inspect -import copy -import uuid -import types -import dnnlib - -#---------------------------------------------------------------------------- - -_version = 6 # internal version number -_decorators = set() # {decorator_class, ...} -_import_hooks = [] # [hook_function, ...] -_module_to_src_dict = dict() # {module: src, ...} -_src_to_module_dict = dict() # {src: module, ...} - -#---------------------------------------------------------------------------- - -def persistent_class(orig_class): - r"""Class decorator that extends a given class to save its source code - when pickled. - - Example: - - from torch_utils import persistence - - @persistence.persistent_class - class MyNetwork(torch.nn.Module): - def __init__(self, num_inputs, num_outputs): - super().__init__() - self.fc = MyLayer(num_inputs, num_outputs) - ... - - @persistence.persistent_class - class MyLayer(torch.nn.Module): - ... - - When pickled, any instance of `MyNetwork` and `MyLayer` will save its - source code alongside other internal state (e.g., parameters, buffers, - and submodules). This way, any previously exported pickle will remain - usable even if the class definitions have been modified or are no - longer available. - - The decorator saves the source code of the entire Python module - containing the decorated class. It does *not* save the source code of - any imported modules. Thus, the imported modules must be available - during unpickling, also including `torch_utils.persistence` itself. - - It is ok to call functions defined in the same module from the - decorated class. However, if the decorated class depends on other - classes defined in the same module, they must be decorated as well. - This is illustrated in the above example in the case of `MyLayer`. - - It is also possible to employ the decorator just-in-time before - calling the constructor. For example: - - cls = MyLayer - if want_to_make_it_persistent: - cls = persistence.persistent_class(cls) - layer = cls(num_inputs, num_outputs) - - As an additional feature, the decorator also keeps track of the - arguments that were used to construct each instance of the decorated - class. The arguments can be queried via `obj.init_args` and - `obj.init_kwargs`, and they are automatically pickled alongside other - object state. 
A typical use case is to first unpickle a previous - instance of a persistent class, and then upgrade it to use the latest - version of the source code: - - with open('old_pickle.pkl', 'rb') as f: - old_net = pickle.load(f) - new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) - misc.copy_params_and_buffers(old_net, new_net, require_all=True) - """ - assert isinstance(orig_class, type) - if is_persistent(orig_class): - return orig_class - - assert orig_class.__module__ in sys.modules - orig_module = sys.modules[orig_class.__module__] - orig_module_src = _module_to_src(orig_module) - - class Decorator(orig_class): - _orig_module_src = orig_module_src - _orig_class_name = orig_class.__name__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._init_args = copy.deepcopy(args) - self._init_kwargs = copy.deepcopy(kwargs) - assert orig_class.__name__ in orig_module.__dict__ - _check_pickleable(self.__reduce__()) - - @property - def init_args(self): - return copy.deepcopy(self._init_args) - - @property - def init_kwargs(self): - return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) - - def __reduce__(self): - fields = list(super().__reduce__()) - fields += [None] * max(3 - len(fields), 0) - if fields[0] is not _reconstruct_persistent_obj: - meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) - fields[0] = _reconstruct_persistent_obj # reconstruct func - fields[1] = (meta,) # reconstruct args - fields[2] = None # state dict - return tuple(fields) - - Decorator.__name__ = orig_class.__name__ - _decorators.add(Decorator) - return Decorator - -#---------------------------------------------------------------------------- - -def is_persistent(obj): - r"""Test whether the given object or class is persistent, i.e., - whether it will save its source code when pickled. - """ - try: - if obj in _decorators: - return True - except TypeError: - pass - return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck - -#---------------------------------------------------------------------------- - -def import_hook(hook): - r"""Register an import hook that is called whenever a persistent object - is being unpickled. A typical use case is to patch the pickled source - code to avoid errors and inconsistencies when the API of some imported - module has changed. - - The hook should have the following signature: - - hook(meta) -> modified meta - - `meta` is an instance of `dnnlib.EasyDict` with the following fields: - - type: Type of the persistent object, e.g. `'class'`. - version: Internal version number of `torch_utils.persistence`. - module_src Original source code of the Python module. - class_name: Class name in the original Python module. - state: Internal state of the object. - - Example: - - @persistence.import_hook - def wreck_my_network(meta): - if meta.class_name == 'MyNetwork': - print('MyNetwork is being imported. I will wreck it!') - meta.module_src = meta.module_src.replace("True", "False") - return meta - """ - assert callable(hook) - _import_hooks.append(hook) - -#---------------------------------------------------------------------------- - -def _reconstruct_persistent_obj(meta): - r"""Hook that is called internally by the `pickle` module to unpickle - a persistent object. 
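# (Editor-added sketch illustrating the persistent_class decorator documented above;
#  not part of the deleted file. It assumes torch_utils.persistence and dnnlib are
#  importable and that this snippet lives in a regular .py file so inspect.getsource()
#  can capture the module source. `ToyLayer` is a hypothetical class.)
import io, pickle, torch
from torch_utils import persistence

@persistence.persistent_class
class ToyLayer(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(dim))

if __name__ == '__main__':
    layer = ToyLayer(8)
    print(layer.init_args, layer.init_kwargs)   # (8,) {}  -- constructor args are recorded
    buf = io.BytesIO()
    pickle.dump(layer, buf)                     # the module's source is embedded in the pickle
    clone = pickle.loads(buf.getvalue())        # loadable even if ToyLayer's code later changes
    print(type(clone).__name__)                 # ToyLayer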
- """ - meta = dnnlib.EasyDict(meta) - meta.state = dnnlib.EasyDict(meta.state) - for hook in _import_hooks: - meta = hook(meta) - assert meta is not None - - assert meta.version == _version - module = _src_to_module(meta.module_src) - - assert meta.type == 'class' - orig_class = module.__dict__[meta.class_name] - decorator_class = persistent_class(orig_class) - obj = decorator_class.__new__(decorator_class) - - setstate = getattr(obj, '__setstate__', None) - if callable(setstate): - setstate(meta.state) # pylint: disable=not-callable - else: - obj.__dict__.update(meta.state) - return obj - -#---------------------------------------------------------------------------- - -def _module_to_src(module): - r"""Query the source code of a given Python module. - """ - src = _module_to_src_dict.get(module, None) - if src is None: - src = inspect.getsource(module) - _module_to_src_dict[module] = src - _src_to_module_dict[src] = module - return src - -def _src_to_module(src): - r"""Get or create a Python module for the given source code. - """ - module = _src_to_module_dict.get(src, None) - if module is None: - module_name = "_imported_module_" + uuid.uuid4().hex - module = types.ModuleType(module_name) - sys.modules[module_name] = module - _module_to_src_dict[module] = src - _src_to_module_dict[src] = module - exec(src, module.__dict__) # pylint: disable=exec-used - return module - -#---------------------------------------------------------------------------- - -def _check_pickleable(obj): - r"""Check that the given object is pickleable, raising an exception if - it is not. This function is expected to be considerably more efficient - than actually pickling the object. - """ - def recurse(obj): - if isinstance(obj, (list, tuple, set)): - return [recurse(x) for x in obj] - if isinstance(obj, dict): - return [[recurse(x), recurse(y)] for x, y in obj.items()] - if isinstance(obj, (str, int, float, bool, bytes, bytearray)): - return None # Python primitive types are pickleable. - if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: - return None # NumPy arrays and PyTorch tensors are pickleable. - if is_persistent(obj): - return None # Persistent objects are pickleable, by virtue of the constructor check. - return obj - with io.BytesIO() as f: - pickle.dump(recurse(obj), f) - -#---------------------------------------------------------------------------- diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py deleted file mode 100644 index 3eb94d95286d8aeffe40ad32ca667e53b4622c4f..0000000000000000000000000000000000000000 --- a/torch_utils/training_stats.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) SenseTime Research. All rights reserved. - -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -"""Facilities for reporting and collecting training statistics across -multiple processes and devices. The interface is designed to minimize -synchronization overhead as well as the amount of boilerplate in user -code.""" - -import re -import numpy as np -import torch -import dnnlib - -from . 
import misc - -#---------------------------------------------------------------------------- - -_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] -_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. -_counter_dtype = torch.float64 # Data type to use for the internal counters. -_rank = 0 # Rank of the current process. -_sync_device = None # Device to use for multiprocess communication. None = single-process. -_sync_called = False # Has _sync() been called yet? -_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor -_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor - -#---------------------------------------------------------------------------- - -def init_multiprocessing(rank, sync_device): - r"""Initializes `torch_utils.training_stats` for collecting statistics - across multiple processes. - - This function must be called after - `torch.distributed.init_process_group()` and before `Collector.update()`. - The call is not necessary if multi-process collection is not needed. - - Args: - rank: Rank of the current process. - sync_device: PyTorch device to use for inter-process - communication, or None to disable multi-process - collection. Typically `torch.device('cuda', rank)`. - """ - global _rank, _sync_device - assert not _sync_called - _rank = rank - _sync_device = sync_device - -#---------------------------------------------------------------------------- - -@misc.profiled_function -def report(name, value): - r"""Broadcasts the given set of scalars to all interested instances of - `Collector`, across device and process boundaries. - - This function is expected to be extremely cheap and can be safely - called from anywhere in the training loop, loss function, or inside a - `torch.nn.Module`. - - Warning: The current implementation expects the set of unique names to - be consistent across processes. Please make sure that `report()` is - called at least once for each unique name by each process, and in the - same order. If a given process has no scalars to broadcast, it can do - `report(name, [])` (empty list). - - Args: - name: Arbitrary string specifying the name of the statistic. - Averages are accumulated separately for each unique name. - value: Arbitrary set of scalars. Can be a list, tuple, - NumPy array, PyTorch tensor, or Python scalar. - - Returns: - The same `value` that was passed in. - """ - if name not in _counters: - _counters[name] = dict() - - elems = torch.as_tensor(value) - if elems.numel() == 0: - return value - - elems = elems.detach().flatten().to(_reduce_dtype) - moments = torch.stack([ - torch.ones_like(elems).sum(), - elems.sum(), - elems.square().sum(), - ]) - assert moments.ndim == 1 and moments.shape[0] == _num_moments - moments = moments.to(_counter_dtype) - - device = moments.device - if device not in _counters[name]: - _counters[name][device] = torch.zeros_like(moments) - _counters[name][device].add_(moments) - return value - -#---------------------------------------------------------------------------- - -def report0(name, value): - r"""Broadcasts the given set of scalars by the first process (`rank = 0`), - but ignores any scalars provided by the other processes. - See `report()` for further details. 
- """ - report(name, value if _rank == 0 else []) - return value - -#---------------------------------------------------------------------------- - -class Collector: - r"""Collects the scalars broadcasted by `report()` and `report0()` and - computes their long-term averages (mean and standard deviation) over - user-defined periods of time. - - The averages are first collected into internal counters that are not - directly visible to the user. They are then copied to the user-visible - state as a result of calling `update()` and can then be queried using - `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the - internal counters for the next round, so that the user-visible state - effectively reflects averages collected between the last two calls to - `update()`. - - Args: - regex: Regular expression defining which statistics to - collect. The default is to collect everything. - keep_previous: Whether to retain the previous averages if no - scalars were collected on a given round - (default: True). - """ - def __init__(self, regex='.*', keep_previous=True): - self._regex = re.compile(regex) - self._keep_previous = keep_previous - self._cumulative = dict() - self._moments = dict() - self.update() - self._moments.clear() - - def names(self): - r"""Returns the names of all statistics broadcasted so far that - match the regular expression specified at construction time. - """ - return [name for name in _counters if self._regex.fullmatch(name)] - - def update(self): - r"""Copies current values of the internal counters to the - user-visible state and resets them for the next round. - - If `keep_previous=True` was specified at construction time, the - operation is skipped for statistics that have received no scalars - since the last update, retaining their previous averages. - - This method performs a number of GPU-to-CPU transfers and one - `torch.distributed.all_reduce()`. It is intended to be called - periodically in the main training loop, typically once every - N training steps. - """ - if not self._keep_previous: - self._moments.clear() - for name, cumulative in _sync(self.names()): - if name not in self._cumulative: - self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - delta = cumulative - self._cumulative[name] - self._cumulative[name].copy_(cumulative) - if float(delta[0]) != 0: - self._moments[name] = delta - - def _get_delta(self, name): - r"""Returns the raw moments that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - assert self._regex.fullmatch(name) - if name not in self._moments: - self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - return self._moments[name] - - def num(self, name): - r"""Returns the number of scalars that were accumulated for the given - statistic between the last two calls to `update()`, or zero if - no scalars were collected. - """ - delta = self._get_delta(name) - return int(delta[0]) - - def mean(self, name): - r"""Returns the mean of the scalars that were accumulated for the - given statistic between the last two calls to `update()`, or NaN if - no scalars were collected. - """ - delta = self._get_delta(name) - if int(delta[0]) == 0: - return float('nan') - return float(delta[1] / delta[0]) - - def std(self, name): - r"""Returns the standard deviation of the scalars that were - accumulated for the given statistic between the last two calls to - `update()`, or NaN if no scalars were collected. 
- """ - delta = self._get_delta(name) - if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): - return float('nan') - if int(delta[0]) == 1: - return float(0) - mean = float(delta[1] / delta[0]) - raw_var = float(delta[2] / delta[0]) - return np.sqrt(max(raw_var - np.square(mean), 0)) - - def as_dict(self): - r"""Returns the averages accumulated between the last two calls to - `update()` as an `dnnlib.EasyDict`. The contents are as follows: - - dnnlib.EasyDict( - NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), - ... - ) - """ - stats = dnnlib.EasyDict() - for name in self.names(): - stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) - return stats - - def __getitem__(self, name): - r"""Convenience getter. - `collector[name]` is a synonym for `collector.mean(name)`. - """ - return self.mean(name) - -#---------------------------------------------------------------------------- - -def _sync(names): - r"""Synchronize the global cumulative counters across devices and - processes. Called internally by `Collector.update()`. - """ - if len(names) == 0: - return [] - global _sync_called - _sync_called = True - - # Collect deltas within current rank. - deltas = [] - device = _sync_device if _sync_device is not None else torch.device('cpu') - for name in names: - delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) - for counter in _counters[name].values(): - delta.add_(counter.to(device)) - counter.copy_(torch.zeros_like(counter)) - deltas.append(delta) - deltas = torch.stack(deltas) - - # Sum deltas across ranks. - if _sync_device is not None: - torch.distributed.all_reduce(deltas) - - # Update cumulative values. - deltas = deltas.cpu() - for idx, name in enumerate(names): - if name not in _cumulative: - _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) - _cumulative[name].add_(deltas[idx]) - - # Return name-value pairs. - return [(name, _cumulative[name]) for name in names] - -#----------------------------------------------------------------------------