Spaces:
Sleeping
Sleeping
| import logging | |
| import sys | |
| import threading | |
| from typing import * | |
| import json | |
| import struct | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from diffusers import EulerAncestralDiscreteScheduler | |
| import diffusers.schedulers.scheduling_euler_ancestral_discrete | |
| from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from safetensors.torch import load_file | |
def fire_in_thread(f, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` on a newly started thread and return immediately.

    The thread object is not returned, so the caller cannot join it or read the
    return value of ``f``.
    """
    worker = threading.Thread(target=f, args=args, kwargs=kwargs)
    worker.start()
| # region Logging | |
def add_logging_arguments(parser):
    """
    Register the console logging options on ``parser``.

    Adds three options through ``parser.add_argument``:
    ``--console_log_level`` (one of DEBUG/INFO/WARNING/ERROR/CRITICAL, default None),
    ``--console_log_file`` (string path, default None) and the flag
    ``--console_log_simple``.
    """
    level_names = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

    # (option string, keyword arguments) in registration order
    option_specs = (
        (
            "--console_log_level",
            {
                "type": str,
                "default": None,
                "choices": level_names,
                "help": "Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
            },
        ),
        (
            "--console_log_file",
            {
                "type": str,
                "default": None,
                "help": "Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
            },
        ),
        (
            "--console_log_simple",
            {
                "action": "store_true",
                "help": "Simple log output / シンプルなログ出力",
            },
        ),
    )

    for option_string, options in option_specs:
        parser.add_argument(option_string, **options)
def setup_logging(args=None, log_level=None, reset=False):
    """
    Install a single handler on the root logger and set the root log level.

    Args:
        args: optional namespace-like object. ``console_log_level``,
            ``console_log_file`` and ``console_log_simple`` are read from it
            when it is given.
        log_level: level name (e.g. ``"INFO"``) or integer level. Takes priority
            over ``args.console_log_level``; when neither is set, INFO is used.
        reset: when the root logger already has handlers, remove them first if
            True; otherwise return without changing anything.

    Handler selection: a ``FileHandler`` (mode ``"w"``) when
    ``args.console_log_file`` is set; otherwise a rich ``RichHandler`` writing to
    stderr unless ``args.console_log_simple`` is set or rich cannot be imported;
    otherwise a ``StreamHandler`` on ``sys.stdout``.
    """
    root = logging.root
    if root.handlers:
        if not reset:
            return
        # remove all handlers (iterate a copy: removeHandler mutates the list)
        for existing in root.handlers[:]:
            root.removeHandler(existing)

    # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
    if log_level is None and args is not None:
        log_level = args.console_log_level
    if log_level is None:
        log_level = "INFO"
    if not isinstance(log_level, int):
        # level names are resolved to the numeric constants of the logging module
        log_level = getattr(logging, log_level)

    msg_init = None
    handler = None
    if args is not None and args.console_log_file:
        handler = logging.FileHandler(args.console_log_file, mode="w")
    elif not args or not args.console_log_simple:
        try:
            from rich.console import Console
            from rich.logging import RichHandler

            handler = RichHandler(console=Console(stderr=True))
        except ImportError:
            # reported through the logger once the fallback handler is installed
            msg_init = "rich is not installed, using basic logging"

    if handler is None:
        handler = logging.StreamHandler(sys.stdout)  # same as print

    formatter = logging.Formatter(
        fmt="%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    root.setLevel(log_level)
    root.addHandler(handler)

    if msg_init is not None:
        logger = logging.getLogger(__name__)
        logger.info(msg_init)
| # endregion | |
| # region PyTorch utils | |
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module) -> None:
    """
    Exchange the ``weight`` tensors of two layers of the same class.

    The submodules of both layers are walked in lockstep via ``modules()``. For every pair whose
    ``layer_to_cpu`` side has a non-None ``weight``:

    1. the ``layer_to_cpu`` weight is copied to the CPU and becomes that module's new weight data;
    2. the ``layer_to_cuda`` weight is copied into the tensor that previously backed the
       ``layer_to_cpu`` weight, and that tensor becomes the ``layer_to_cuda`` weight data.

    Only ``weight`` is handled; other parameters or buffers (e.g. ``bias``) are not touched.

    NOTE(review): presumably the weights of ``layer_to_cpu`` live on a CUDA device and those of
    ``layer_to_cuda`` on the CPU when this is called - confirm against the callers.

    Args:
        layer_to_cpu: layer whose weights are moved to the CPU.
        layer_to_cuda: layer that receives the storage released by ``layer_to_cpu``.

    Raises:
        AssertionError: if the two layers are not instances of the same class.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # Collect (module_to_cpu, module_to_cuda, weight data of module_to_cpu, weight data of module_to_cuda)
    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    # All copies are issued on a dedicated stream
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        # (cpu_data_view is unpacked but not used in either loop)
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        # wait for the non-blocking copies above before cuda_data_view is overwritten below
        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
def weighs_to_device(layer: nn.Module, device: torch.device):
    """
    Move the ``weight`` data of every submodule of ``layer`` to ``device``.

    Submodules without a ``weight`` attribute, or whose ``weight`` is None, are skipped.
    The transfer is requested with ``non_blocking=True``. Only ``weight`` is moved.
    """
    for submodule in layer.modules():
        weight = getattr(submodule, "weight", None)
        if weight is None:
            continue
        submodule.weight.data = weight.data.to(device, non_blocking=True)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """
    Translate a dtype name into the matching ``torch.dtype``.

    Args:
        s: dtype name or alias; ``None`` selects ``default_dtype``
        default_dtype: value returned when ``s`` is None (may itself be None)

    Returns:
        torch.dtype: the dtype named by ``s``

    Raises:
        ValueError: if ``s`` is not one of the known names

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype

    # (accepted spellings, attribute name on the torch module). The attribute is looked up
    # only for the matching row, so unrelated dtypes are never touched.
    alias_table = (
        (["bf16", "bfloat16"], "bfloat16"),
        (["fp16", "float16"], "float16"),
        (["fp32", "float32", "float"], "float32"),
        (["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"], "float8_e4m3fn"),
        (["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"], "float8_e4m3fnuz"),
        (["fp8_e5m2", "e5m2", "float8_e5m2"], "float8_e5m2"),
        (["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"], "float8_e5m2fnuz"),
        (["fp8", "float8"], "float8_e4m3fn"),  # default fp8
    )
    for spellings, torch_attr in alias_table:
        if s in spellings:
            return getattr(torch, torch_attr)

    raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Optional[Dict[str, Any]] = None):
    """
    Write ``tensors`` to ``filename`` in a safetensors-style layout, one tensor at a time.

    File layout produced here: an 8-byte little-endian header length, a JSON header padded
    with spaces so that header plus length prefix is a multiple of 256 bytes, then the raw
    bytes of every non-empty tensor in dict iteration order. Each header entry records the
    dtype code, the shape and the ``[start, end]`` byte offsets inside the data section.

    Args:
        tensors: mapping of tensor name to tensor. CUDA tensors are copied to the CPU one
            at a time while writing.
        filename: path of the file to create (overwritten if it exists).
        metadata: optional mapping stored under ``__metadata__``. Non-string values are
            converted with ``str()`` after printing a warning.

    Raises:
        KeyError: if a tensor dtype has no entry in the dtype table (raised while the
            header is built, before the file is opened).
        ValueError: if a metadata key is not a string.
    """
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        # float8 dtypes are looked up by name so that torch builds without them still work
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Using memory efficient save file: {filename}")

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        # an empty tensor has size 0, so its start and end offsets coincide
        size = v.numel() * v.element_size()
        header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
        offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue

            # detach first: Tensor.numpy() refuses tensors that require grad
            v = v.detach()
            if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                v = v.unsqueeze(0)

            if v.is_cuda:
                # reinterpret as bytes on the owning device, then copy this one tensor to the CPU
                with torch.cuda.device(v.device):
                    tensor_bytes = v.contiguous().view(torch.uint8).cpu()
            else:
                tensor_bytes = v.contiguous().view(torch.uint8)
            tensor_bytes.numpy().tofile(f)
class MemoryEfficientSafeOpen:
    """
    Reader for safetensors-style files that loads one tensor at a time with plain file reads.

    The JSON header is parsed once on construction; ``get_tensor`` then seeks to the byte
    range recorded for a key and builds a tensor from those bytes. Intended to be used as a
    context manager so that the underlying file is closed on exit.
    """

    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        """Return the tensor names in the header (the ``__metadata__`` entry is excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """
        Read the tensor stored under ``key``.

        Raises:
            KeyError: if ``key`` is not a tensor entry of the header.
        """
        # "__metadata__" lives in the header but has no data_offsets, so it is not a tensor
        if key not in self.header or key == "__metadata__":
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size (plus the 8-byte length prefix)
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Return ``(parsed JSON header, header size in bytes)`` read from the start of the file."""
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Build a tensor of the dtype and shape given in ``metadata`` from raw bytes (None means empty)."""
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    # Declared static: the helper takes no ``self`` but is invoked as ``self._get_torch_dtype(...)``
    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a header dtype code to a ``torch.dtype``; returns None for an unknown code."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    # Declared static for the same reason as _get_torch_dtype
    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret ``byte_tensor`` as the float8 dtype named by ``dtype_str`` and reshape it.

        Raises:
            ValueError: if this torch build does not provide the requested float8 dtype.
        """
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
    path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
    """
    Load every tensor of a safetensors file into a dict.

    Args:
        path: file to read.
        device: target device for the loaded tensors.
        disable_mmap: when True, read with ``MemoryEfficientSafeOpen`` (experimental loader)
            instead of ``safetensors.torch.load_file``.
        dtype: dtype every tensor is converted to; ``None`` keeps the stored dtypes.

    Returns:
        Mapping of tensor name to tensor.

    Note:
        On the ``load_file`` path, if loading directly onto ``device`` fails the file is
        loaded again without a device argument; in that case the tensors are not moved to
        ``device`` by this function.
    """
    if disable_mmap:
        # use experimental loader
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict

    try:
        state_dict = load_file(path, device=device)
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt / SystemExit are not swallowed
        state_dict = load_file(path)  # prevent device invalid Error

    if dtype is not None:
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(dtype=dtype)
    return state_dict
| # endregion | |
| # region Image utils | |
def pil_resize(image, size, interpolation=Image.LANCZOS):
    """
    Resize an OpenCV-style image array using Pillow's resampling.

    Args:
        image: array in OpenCV channel order: 3-D with 3 (BGR) or 4 (BGRA) channels,
            or 2-D single-channel (grayscale).
        size: target size passed to ``PIL.Image.resize``.
        interpolation: Pillow resampling filter.

    Returns:
        The resized image as an array in the same channel order as the input.
    """
    if len(image.shape) == 2:
        # Single-channel input has no channel order to swap, and cvtColor with BGR2RGB
        # rejects it, so resize the array directly.
        return np.array(Image.fromarray(image).resize(size, interpolation))

    has_alpha = len(image.shape) == 3 and image.shape[2] == 4
    to_pil_code = cv2.COLOR_BGRA2RGBA if has_alpha else cv2.COLOR_BGR2RGB
    to_cv2_code = cv2.COLOR_RGBA2BGRA if has_alpha else cv2.COLOR_RGB2BGR

    pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_code))
    resized_pil = pil_image.resize(size, interpolation)

    # Convert back to cv2 format
    return cv2.cvtColor(np.array(resized_pil), to_cv2_code)
| # endregion | |
| # TODO make inf_utils.py | |
| # region Gradual Latent hires fix | |
class GradualLatent:
    """
    Parameter holder for the gradual latent hires fix, plus the resize helper it uses.

    The constructor only stores its arguments. ``interpolate`` resizes a tensor with
    bicubic interpolation and optionally sharpens the result with an unsharp mask.
    """

    # attribute names in the order they are printed by __str__
    _STR_FIELDS = (
        "ratio",
        "start_timesteps",
        "every_n_steps",
        "ratio_step",
        "s_noise",
        "gaussian_blur_ksize",
        "gaussian_blur_sigma",
        "gaussian_blur_strength",
        "unsharp_target_x",
    )

    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        self.ratio = ratio
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        self.ratio_step = ratio_step
        self.s_noise = s_noise
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in self._STR_FIELDS)
        return f"GradualLatent({rendered})"

    def apply_unshark_mask(self, x: torch.Tensor):
        """Return ``x`` sharpened by an unsharp mask; ``x`` itself when no blur kernel size is set."""
        if self.gaussian_blur_ksize is None:
            return x
        blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
        # add back the scaled difference between the input and its blurred copy
        return x + (x - blurred) * self.gaussian_blur_strength

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
        """
        Resize ``x`` to ``resized_size`` with bicubic interpolation, keeping its dtype.

        bfloat16 input is converted to float32 for the interpolation and converted back
        afterwards. The unsharp mask is applied when ``unsharp`` is true and a blur kernel
        size is configured.
        """
        source_dtype = x.dtype
        working = x.float() if source_dtype == torch.bfloat16 else x
        resized = torch.nn.functional.interpolate(working, size=resized_size, mode="bicubic", align_corners=False)
        resized = resized.to(dtype=source_dtype)

        # apply unsharp mask
        if not (unsharp and self.gaussian_blur_ksize):
            return resized
        return self.apply_unshark_mask(resized)
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    """
    Euler ancestral scheduler whose ``step`` can also resize the latent (gradual latent).

    Until ``set_gradual_latent_params`` supplies a target size, ``step`` performs the plain
    Euler ancestral update. With a target size set, the sample is resized through the
    configured ``GradualLatent`` and the added noise is drawn at the resized resolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # both remain None until set_gradual_latent_params() is called
        self.gradual_latent = None
        self.resized_size = None

    def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
        """Store the target size and the ``GradualLatent`` settings used by the next ``step`` calls."""
        self.gradual_latent = gradual_latent
        self.resized_size = size

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Compute the sample for the previous timestep from the model output.

        Args:
            model_output (`torch.FloatTensor`):
                Output of the diffusion model for the current step.
            timestep (`float`):
                Current timestep; must be a value taken from `scheduler.timesteps`, not an integer index.
            sample (`torch.FloatTensor`):
                Current sample of the diffusion process.
            generator (`torch.Generator`, *optional*):
                Random number generator used for the ancestral noise.
            return_dict (`bool`):
                When `True`, return an `EulerAncestralDiscreteSchedulerOutput`; otherwise a tuple.

        Returns:
            `EulerAncestralDiscreteSchedulerOutput` or `tuple`:
                With `return_dict=False`, a one-element tuple holding the previous sample.

        Raises:
            ValueError: if `timestep` is an integer or integer tensor, or if the configured
                prediction type is unknown.
            NotImplementedError: if the configured prediction type is `sample`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            message = (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )
            raise ValueError(message)

        if not self.is_scale_input_called:
            print(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        # split the next sigma into a deterministic part (down) and a noise part (up)
        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        dt = sigma_down - sigma

        if self.resized_size is None:
            # plain Euler ancestral update at the current resolution
            prev_sample = sample + derivative * dt
            s_noise = 1.0
            noise_shape = model_output.shape
        else:
            print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
            s_noise = self.gradual_latent.s_noise

            if self.gradual_latent.unsharp_target_x:
                # step first, then resize (and sharpen) the result
                prev_sample = self.gradual_latent.interpolate(sample + derivative * dt, self.resized_size)
            else:
                # resize sample and derivative separately (derivative without sharpening), then step
                sample = self.gradual_latent.interpolate(sample, self.resized_size)
                derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
                prev_sample = sample + derivative * dt

            noise_shape = (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1])

        noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
            noise_shape,
            dtype=model_output.dtype,
            device=model_output.device,
            generator=generator,
        )
        prev_sample = prev_sample + noise * sigma_up * s_noise

        # upon completion increase step index by one
        self._step_index += 1

        if return_dict:
            return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
        return (prev_sample,)
| # endregion | |