import functools
import importlib
import os
from functools import partial
from inspect import isfunction

import fsspec
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def get_string_from_tuple(s):
    try:
        # Check if the string starts and ends with parentheses
        if s[0] == "(" and s[-1] == ")":
            # Convert the string to a tuple
            t = eval(s)
            # Check if the type of t is tuple
            if type(t) == tuple:
                return t[0]
    except Exception:
        pass
    return s


def is_power_of_two(n):
    """
    Return True if n is a power of 2, otherwise return False.

    A positive power of 2 has exactly one bit set in its binary
    representation; subtracting 1 flips that bit and sets every bit below
    it, so `n & (n - 1)` is 0 exactly when n is a power of 2. Values <= 0
    are never powers of 2.

    (Explanation via chat.openai.com/chat.)
    """
    if n <= 0:
        return False
    return (n & (n - 1)) == 0


def autocast(f, enabled=True):
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast


def load_partial_from_config(config):
    return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))


def log_txt_as_img(wh, xc, size=10):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        nc = int(40 * (wh[0] / 256))
        if isinstance(xc[bi], list):
            text_seq = xc[bi][0]
        else:
            text_seq = xc[bi]
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
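
# Illustrative usage sketch (not part of the original module): log_txt_as_img
# renders a list of captions into an image batch suitable for logging. The
# captions below are made up, and the call assumes the `data/DejaVuSans.ttf`
# font referenced above is present relative to the working directory.
def _demo_log_txt_as_img():
    captions = ["a photo of an astronaut riding a horse", "a watercolor landscape"]
    grid = log_txt_as_img((256, 256), captions, size=10)
    # grid is a float tensor of shape (2, 3, 256, 256), scaled to [-1, 1].
    return grid
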
Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) return NewCls def make_path_absolute(path): fs, p = fsspec.core.url_to_fs(path) if fs.protocol == "file": return os.path.abspath(p) return path def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def isheatmap(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 2 def isneighbors(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) def exists(x): return x is not None def expand_dims_like(x, y): while x.dim() != y.dim(): x = x.unsqueeze(-1) return x def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False, invalidate_cache=True): package_directory_name = os.path.basename(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) module, cls = string.rsplit(".", 1) if invalidate_cache: importlib.invalidate_caches() if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=package_directory_name), cls) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError( f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" ) return x[(...,) + (None,) * dims_to_append] def load_model_from_config(config, ckpt, verbose=True, freeze=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) else: raise NotImplementedError model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if freeze: for param in model.parameters(): param.requires_grad = False model.eval() return model def get_configs_path() -> str: """ Get the `configs` directory. For a working copy, this is the one in the root of the repository, but for an installed copy, it's in the `sgm` package (see pyproject.toml). 
""" this_dir = os.path.dirname(__file__) candidates = ( os.path.join(this_dir, "configs"), os.path.join(this_dir, "..", "configs"), ) for candidate in candidates: candidate = os.path.abspath(candidate) if os.path.isdir(candidate): return candidate raise FileNotFoundError(f"Could not find SGM configs in {candidates}")