Spaces:
Running
on
Zero
Running
on
Zero
| # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
| # | |
| # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from NVIDIA CORPORATION or | |
| # its affiliates is strictly prohibited. | |
| """Miscellaneous utility classes and functions.""" | |
| import ctypes | |
| import fnmatch | |
| import importlib | |
| import inspect | |
| import numpy as np | |
| import os | |
| import shutil | |
| import sys | |
| import types | |
| import io | |
| import pickle | |
| import re | |
| import requests | |
| import html | |
| import hashlib | |
| import glob | |
| import tempfile | |
| import urllib | |
| import urllib.request | |
| import uuid | |
| from distutils.util import strtobool | |
| from typing import Any, List, Tuple, Union | |
| import torch | |
| # Util classes | |
| # ------------------------------------------------------------------------------------------ | |
def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max=1.0):
    """Compute a weight that balances an adversarial loss against a reconstruction loss.

    The weight is ``||d recon_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4)``,
    clamped to ``[0, disc_weight_max]`` and detached from the autograd graph.

    Args:
        recon_loss: Scalar reconstruction loss tensor that depends on ``last_layer``.
        g_loss: Scalar generator/adversarial loss tensor that depends on ``last_layer``.
        last_layer: Tensor both losses are differentiated against (presumably the
            weight of the network's final layer -- confirm at call sites).
        disc_weight_max: Upper bound of the returned weight.

    Returns:
        A detached scalar tensor.
    """
    def _grad_norm(loss):
        # retain_graph=True keeps the graph usable afterwards (e.g. for the
        # second gradient query below and any later backward pass).
        grad = torch.autograd.grad(loss, last_layer, retain_graph=True)[0]
        return torch.norm(grad)

    recon_norm = _grad_norm(recon_loss)
    adv_norm = _grad_norm(g_loss)
    # The 1e-4 term avoids division by zero when the adversarial gradient vanishes.
    ratio = recon_norm / (adv_norm + 1e-4)
    return torch.clamp(ratio, 0.0, disc_weight_max).detach()
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax.

    ``d.key`` reads, ``d.key = v`` writes and ``del d.key`` deletes ``d["key"]``.
    Reading a missing key raises ``AttributeError`` so that ``hasattr`` and
    ``getattr(d, name, default)`` work as usual.
    """

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so real dict methods
        # (``keys``, ``items``, ...) still take precedence over same-named keys.
        try:
            return self[name]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        # Attributes are stored as dict items, never in the instance __dict__.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # NOTE: deleting a missing key raises KeyError (kept for compatibility).
        del self[name]
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    On construction the instance installs itself as both ``sys.stdout`` and
    ``sys.stderr``; ``close()`` (also called when leaving a ``with`` block)
    restores the streams that were active before.
    """

    def __init__(self,
                 file_name: str = None,
                 file_mode: str = "w",
                 should_flush: bool = True):
        self.file = None
        if file_name is not None:
            self.file = open(file_name, file_mode)
        self.should_flush = should_flush

        # Keep the streams being replaced so close() can put them back.
        self.stdout, self.stderr = sys.stdout, sys.stderr
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Mirror ``text`` (bytes are decoded first) to the file and the original stdout."""
        message = text.decode() if isinstance(text, bytes) else text
        # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
        if len(message) == 0:
            return
        if self.file is not None:
            self.file.write(message)
        self.stdout.write(message)
        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the mirror file (when open) and the original stdout."""
        if self.file is not None:
            self.file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, restore stdout/stderr if still redirected to us, and close the file."""
        self.flush()
        # Restore a stream only if we still own it; with several loggers this
        # avoids undoing another logger's redirection out of order.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr
        if self.file is not None:
            self.file.close()
            self.file = None
| # Cache directories | |
| # ------------------------------------------------------------------------------------------ | |
# Explicit cache root set via set_cache_dir(); None means "derive from the environment".
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Override the cache root used by make_cache_dir_path (takes priority over env vars)."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join ``paths`` onto the dnnlib cache root and return the result.

    The root is the first available of: the value given to set_cache_dir(),
    ``$DNNLIB_CACHE_DIR``, ``$HOME/.cache/dnnlib``,
    ``$USERPROFILE/.cache/dnnlib`` and finally
    ``<system temp dir>/.cache/dnnlib``. No directory is created.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    env = os.environ
    candidates = (
        ('DNNLIB_CACHE_DIR', ()),
        ('HOME', ('.cache', 'dnnlib')),
        ('USERPROFILE', ('.cache', 'dnnlib')),
    )
    for var_name, suffix in candidates:
        if var_name in env:
            return os.path.join(env[var_name], *suffix, *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
| # Small util functions | |
| # ------------------------------------------------------------------------------------------ | |
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration as a human-readable string with up to three units.

    The input is rounded to whole seconds with ``np.rint`` and formatted as
    ``"42s"``, ``"3m 07s"``, ``"2h 05m 09s"`` or ``"1d 04h 30m"``
    (seconds are dropped once days are shown).
    """
    total = int(np.rint(seconds))
    minutes_total, sec = divmod(total, 60)
    hours_total, minute = divmod(minutes_total, 60)
    days, hour = divmod(hours_total, 24)

    if total < 60:
        return "{0}s".format(total)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(minutes_total, sec)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(hours_total, minute, sec)
    return "{0}d {1:02}h {2:02}m".format(days, hour, minute)
def format_time_brief(seconds: Union[int, float]) -> str:
    """Render a duration compactly, showing only its two most significant units.

    The input is rounded to whole seconds with ``np.rint`` and formatted as
    ``"42s"``, ``"3m 07s"``, ``"2h 05m"`` or ``"1d 04h"``.
    """
    total = int(np.rint(seconds))
    if total < 60:
        return "{0}s".format(total)
    total_minutes, sec = divmod(total, 60)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(total_minutes, sec)
    total_hours, minute = divmod(total_minutes, 60)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m".format(total_hours, minute)
    days, hour = divmod(total_hours, 24)
    return "{0}d {1:02}h".format(days, hour)
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepts the same answers as the removed ``distutils.util.strtobool``
    (case-insensitive, surrounding whitespace ignored):
    ``y/yes/t/true/on/1`` mean yes and ``n/no/f/false/off/0`` mean no.
    Any other answer re-prompts.

    Returns:
        ``True`` for a yes answer, ``False`` for a no answer (a real bool,
        not the 1/0 int strtobool returned).
    """
    # Parsed locally: distutils (and strtobool) no longer exist in Python 3.12+.
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    while True:
        print("{0} [y/n]".format(question))
        answer = input().strip().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
def tuple_product(t: Tuple) -> Any:
    """Return the product of the elements of ``t``; an empty tuple yields ``1``."""
    product = 1
    for factor in t:
        # In-place multiply, so element types decide the accumulator's type.
        product *= factor
    return product
# NumPy type name -> ctypes scalar type of the same byte size.
_str_to_ctype = dict(
    uint8=ctypes.c_ubyte,
    uint16=ctypes.c_uint16,
    uint32=ctypes.c_uint32,
    uint64=ctypes.c_uint64,
    int8=ctypes.c_byte,
    int16=ctypes.c_int16,
    int32=ctypes.c_int32,
    int64=ctypes.c_int64,
    float32=ctypes.c_float,
    float64=ctypes.c_double,
)


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Map a type descriptor to an equally sized ``(np.dtype, ctypes type)`` pair.

    ``type_obj`` may be a type-name string such as ``"float32"``, an object
    with a ``__name__`` attribute (e.g. ``np.float32``), or an object with a
    ``name`` attribute. Unsupported names fail an ``assert``.

    Raises:
        RuntimeError: if no type name can be derived from ``type_obj``.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    else:
        # Prefer __name__ over name, exactly like an if/elif chain would.
        for attr in ("__name__", "name"):
            if hasattr(type_obj, attr):
                type_str = getattr(type_obj, attr)
                break
        else:
            raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()
    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
def is_pickleable(obj: Any) -> bool:
    """Return True if ``obj`` can be serialized with :mod:`pickle`.

    The object is pickled into a throwaway in-memory buffer. Any ordinary
    exception raised while pickling (``PicklingError``, ``TypeError``,
    ``AttributeError``, ...) means "not pickleable"; ``KeyboardInterrupt``
    and ``SystemExit`` are no longer swallowed.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
    except Exception:
        return False
    return True
| # Functionality to import modules/objects by name, and call functions by name | |
| # ------------------------------------------------------------------------------------------ | |
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    The longest importable dotted prefix whose remainder resolves as an
    attribute path wins, e.g. ``"os.path.join"`` -> ``(os.path, "join")``.
    The shorthands ``np.`` and ``tf.`` expand to ``numpy.`` and ``tensorflow.``.

    Raises:
        The underlying exception if a module exists but fails to import, or
        AttributeError if the attribute path cannot be resolved; otherwise
        ``ImportError(obj_name)``.
    """
    # allow convenience shorthands, substitute them by full names
    # (dots escaped so that e.g. "np_utils.f" is NOT rewritten to "numpy.utils.f")
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:]))
                  for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Failures are diagnosed below; KeyboardInterrupt/SystemExit propagate.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError as err:
            # "No module named '<module_name>'" only means this alternative does
            # not exist; any other ImportError originates inside the module.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty ``obj_name`` returns ``module`` itself; otherwise each dotted part
    is resolved with ``getattr`` (AttributeError propagates).
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
def get_obj_by_name(name: str) -> Any:
    """Resolve a fully-qualified dotted name (e.g. ``"os.path.join"``) to its object."""
    owner_module, attr_path = get_module_from_obj_name(name)
    return get_obj_from_module(owner_module, attr_path)
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve ``func_name`` via get_obj_by_name and call it with the remaining arguments.

    ``func_name`` is keyword-only and effectively required; the resolved
    object must be callable. Both conditions are checked with ``assert``.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Instantiate the class at dotted path ``class_name`` with the given arguments.

    Delegates to call_func_by_name, since constructing a class is just calling it.
    """
    return call_func_by_name(*args, func_name=class_name, **kwargs)
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the file of the module that defines ``obj_name``."""
    owner_module = get_module_from_obj_name(obj_name)[0]
    module_file = inspect.getfile(owner_module)
    return os.path.dirname(module_file)
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    ``obj`` qualifies when it is callable and the module named by
    ``obj.__module__`` (looked up in ``sys.modules``) holds this very object
    under ``obj.__name__``. Objects without ``__name__``/``__module__``
    (e.g. ``functools.partial``) or from an unloaded module yield ``False``.
    """
    if not callable(obj):
        return False
    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    if name is None or module is None:
        return False
    # Compare identity, not mere name presence: a nested or foreign function
    # that happens to share a module-level name must not qualify.
    return vars(module).get(name) is obj
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified ``module.name`` of a top-level function.

    For functions defined in the ``__main__`` script, the script's file name
    without extension replaces ``__main__`` as the module part.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        script_path = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_path))[0]
    return module_name + "." + obj.__name__
| # File system helpers | |
| # ------------------------------------------------------------------------------------------ | |
def list_dir_recursively_with_ignore(
        dir_path: str,
        ignores: List[str] = None,
        add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    ``ignores`` holds ``fnmatch`` patterns matched against bare file and
    directory names; matching directories are not descended into. The first
    tuple element is ``os.path.join(root, name)`` (absolute only if
    ``dir_path`` is), the second is relative to ``dir_path``, optionally
    prefixed by the base name of ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores

    def _ignored(name):
        return any(fnmatch.fnmatch(name, pat) for pat in patterns)

    entries = []
    for root, subdirs, filenames in os.walk(dir_path, topdown=True):
        # Prune in place so os.walk does not descend into ignored directories.
        subdirs[:] = [d for d in subdirs if not _ignored(d)]
        for name in filenames:
            if _ignored(name):
                continue
            abs_path = os.path.join(root, name)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            entries.append((abs_path, rel_path))
    return entries
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    File contents are copied with ``shutil.copyfile`` (existing ``dst`` files
    are overwritten). A ``dst`` without a directory component is written
    relative to the current working directory.
    """
    for item in files:
        src, dst = item[0], item[1]
        target_dir_name = os.path.dirname(dst)
        # dirname() is "" for a bare file name and os.makedirs("") raises;
        # exist_ok also removes the exists()/makedirs() race between processes.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)
        shutil.copyfile(src, dst)
| # URL helpers | |
| # ------------------------------------------------------------------------------------------ | |
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    A string qualifies when it contains ``"://"`` and both the URL and its
    root (``urljoin(obj, "/")``) parse with a scheme and a dotted network
    location, so e.g. ``http://localhost/`` is rejected. With
    ``allow_file_urls=True`` any string starting with ``file://`` is accepted.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        # requests.compat.urlparse/urljoin are re-exports of these functions.
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except ValueError:
        # Raised by urllib.parse for malformed network locations (e.g. "[::1").
        return False
    return True
def open_url(url: str,
             cache_dir: str = None,
             num_attempts: int = 10,
             verbose: bool = True,
             return_filename: bool = False,
             cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: An ``http(s)://`` URL, a ``file://`` URL, or a plain local path.
        cache_dir: Download cache directory; defaults to
            ``make_cache_dir_path('downloads')``.
        num_attempts: Total number of download attempts before giving up.
        verbose: Print download progress to stdout.
        return_filename: Return a path instead of a file object. Must not be
            combined with ``cache=False``.
        cache: Look up and store downloads in ``cache_dir``, keyed by the MD5
            of the URL.

    Returns:
        A path when ``return_filename`` is set; otherwise an open binary file
        for local/cached files, or an ``io.BytesIO`` over freshly downloaded data.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive HTML interstitial
                    # instead of the file. Decode leniently: a small genuine
                    # binary payload is usually not valid UTF-8 and a strict
                    # decode would make every attempt fail.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get(
                                "Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                # Retry against the confirmation link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    match = re.search(
                        r'filename="([^"]*)"',
                        res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # Retry on ordinary errors; KeyboardInterrupt/SystemExit propagate.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(
            cache_dir,
            "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
| class InfiniteSampler(torch.utils.data.Sampler): | |
| def __init__(self, | |
| dataset, | |
| rank=0, | |
| num_replicas=1, | |
| shuffle=True, | |
| seed=0, | |
| window_size=0.5): | |
| assert len(dataset) > 0 | |
| assert num_replicas > 0 | |
| assert 0 <= rank < num_replicas | |
| assert 0 <= window_size <= 1 | |
| super().__init__(dataset) | |
| self.dataset = dataset | |
| self.rank = rank | |
| self.num_replicas = num_replicas | |
| self.shuffle = shuffle | |
| self.seed = seed | |
| self.window_size = window_size | |
| def __iter__(self): | |
| order = np.arange(len(self.dataset)) | |
| rnd = None | |
| window = 0 | |
| if self.shuffle: | |
| rnd = np.random.RandomState(self.seed) | |
| rnd.shuffle(order) | |
| window = int(np.rint(order.size * self.window_size)) | |
| idx = 0 | |
| while True: | |
| i = idx % order.size | |
| if idx % self.num_replicas == self.rank: | |
| yield order[i] | |
| if window >= 2: | |
| j = (i - rnd.randint(window)) % order.size | |
| order[i], order[j] = order[j], order[i] | |
| idx += 1 | |
def requires_grad(model, flag=True):
    """Set ``requires_grad = flag`` on every tensor yielded by ``model.parameters()``."""
    for param in model.parameters():
        param.requires_grad = flag