# File: test/SUPIR/util.py
# Last commit: ccfcf8d by quantumiracle ("fix")
# (page-header residue from the file view: raw | history | blame | 6.42 kB)
import os
import torch
import numpy as np
import cv2
from PIL import Image
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
def get_state_dict(d):
    """Return ``d['state_dict']`` when ``d`` has that key, otherwise ``d`` itself.

    Lets callers accept both wrapped (``{'state_dict': ...}``) and plain
    checkpoint dicts. ``d`` must support ``.get``.
    """
    missing = object()
    inner = d.get('state_dict', missing)
    return d if inner is missing else inner
def load_state_dict(ckpt_path, location='cpu'):
    """Load a checkpoint file and unwrap a top-level ``'state_dict'`` entry.

    Files whose extension is ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other file goes through
    ``torch.load``. The ``'state_dict'`` unwrap is applied twice on the
    ``torch.load`` path and once on the safetensors path.

    Args:
        ckpt_path: Path to the checkpoint file.
        location: Device the tensors are mapped to (default ``'cpu'``).

    Returns:
        The (unwrapped) state dict.
    """
    ext = os.path.splitext(ckpt_path)[1].lower()
    if ext == ".safetensors":
        import safetensors.torch
        weights = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        # NOTE(security): torch.load is pickle-based; only load trusted files.
        weights = torch.load(ckpt_path, map_location=torch.device(location))
        weights = weights.get('state_dict', weights)
    weights = weights.get('state_dict', weights)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return weights
def create_model(config_path):
    """Build the model described by the ``model`` node of an OmegaConf file.

    No weights are loaded here; the instantiated module is moved to the CPU.

    Args:
        config_path: Path to the config file read by ``OmegaConf.load``.

    Returns:
        The instantiated model (on CPU).
    """
    model_cfg = OmegaConf.load(config_path).model
    net = instantiate_from_config(model_cfg)
    net = net.cpu()
    print(f'Loaded model config from [{config_path}]')
    return net
def resolve_ckpt_path(path_or_hub):
    """Turn a checkpoint reference into a local file path.

    The reference (surrounding whitespace is stripped) is either:

    * an existing local path (``~`` is expanded if the literal path does not
      exist), returned as-is; or
    * a Hugging Face Hub reference ``<namespace>/<repo>/<path/in/repo>`` ending
      in ``.ckpt``, ``.pt``, ``.bin`` or ``.safetensors``, which is fetched with
      ``hf_hub_download`` and whose local path is returned.

    Args:
        path_or_hub: Local path or Hub reference string.

    Returns:
        Path to a local checkpoint file.

    Raises:
        FileNotFoundError: If the reference is neither an existing file nor a
            well-formed Hub reference.
    """
    path_or_hub = path_or_hub.strip()
    if os.path.exists(path_or_hub):
        return path_or_hub
    expanded = os.path.expanduser(path_or_hub)
    if expanded != path_or_hub and os.path.exists(expanded):
        return expanded
    parts = path_or_hub.split("/")
    is_hub_ref = (
        path_or_hub.endswith((".ckpt", ".pt", ".bin", ".safetensors"))
        # namespace + repo name + a non-empty filename inside the repo
        and len(parts) >= 3
        # an empty component means an absolute path ('/a/b') or a '//' typo
        and all(parts)
        # './', '../' and '~/' prefixes denote (missing) local paths, not repos
        and not parts[0].startswith((".", "~"))
        # e.g. Windows drive paths such as 'C:/models/x.ckpt'
        and not os.path.isabs(path_or_hub)
    )
    if is_hub_ref:
        repo_id = "/".join(parts[:2])
        filename = "/".join(parts[2:])
        return hf_hub_download(repo_id=repo_id, filename=filename)
    raise FileNotFoundError(f"Could not resolve checkpoint path: {path_or_hub}")
def load_checkpoint(path):
    """Load a checkpoint file onto the CPU and return it unmodified.

    Unlike ``load_state_dict`` this does not unwrap a ``'state_dict'`` key.
    The ``.safetensors`` extension is matched case-insensitively (as
    ``load_state_dict`` does); any other file is read with ``torch.load``.

    Args:
        path: Local path to the checkpoint.

    Returns:
        Whatever the loader returns (a dict of tensors for safetensors files).
    """
    if path.lower().endswith(".safetensors"):
        return load_safetensors(path, device='cpu')
    # NOTE(security): torch.load is pickle-based; only load trusted files.
    return torch.load(path, map_location='cpu')
def _load_weights_into(model, ckpt_ref):
    """Resolve *ckpt_ref*, load it, and copy its weights into *model* non-strictly."""
    state_dict = load_checkpoint(resolve_ckpt_path(ckpt_ref))
    # Checkpoints that wrap their weights under 'state_dict' must be unwrapped:
    # with strict=False, load_state_dict would otherwise match no keys and
    # silently leave the model uninitialised.
    state_dict = state_dict.get('state_dict', state_dict)
    model.load_state_dict(state_dict, strict=False)


def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
    """Instantiate a SUPIR model from a config file and load its checkpoints.

    Checkpoints named by the config keys ``SDXL_CKPT``, ``SUPIR_CKPT`` and (when
    ``SUPIR_sign`` is given) ``SUPIR_CKPT_<sign>`` are loaded in that order,
    each non-strictly, so later checkpoints override earlier ones. Keys that
    are absent or empty in the config are skipped.

    Args:
        config_path: Path to the OmegaConf config file.
        SUPIR_sign: Optional variant, ``'F'`` or ``'Q'``.
        load_default_setting: If true, also return ``config.default_setting``.

    Returns:
        ``model``, or ``(model, config.default_setting)`` when
        ``load_default_setting`` is true.

    Raises:
        AssertionError: If ``SUPIR_sign`` is not None, 'F' or 'Q'.
    """
    # Validate up front, before any (slow) model building or checkpoint I/O.
    if SUPIR_sign is not None:
        assert SUPIR_sign in ['F', 'Q'], "SUPIR_sign must be 'F' or 'Q'"
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    ckpt_keys = ["SDXL_CKPT", "SUPIR_CKPT"]
    if SUPIR_sign is not None:
        ckpt_keys.append(f"SUPIR_CKPT_{SUPIR_sign}")
    for key in ckpt_keys:
        if config.get(key):
            _load_weights_into(model, config[key])
    if load_default_setting:
        return model, config.default_setting
    return model
def load_QF_ckpt(config_path):
    """Load the two SUPIR variant checkpoints referenced by a config file.

    Reads ``SUPIR_CKPT_F`` and ``SUPIR_CKPT_Q`` from the config, resolves each
    (local path or Hub reference) and loads it raw onto the CPU. The F
    checkpoint is loaded first.

    Returns:
        ``(ckpt_Q, ckpt_F)`` -- note Q comes first.
    """
    cfg = OmegaConf.load(config_path)
    loaded = {}
    for sign in ("F", "Q"):
        ref = getattr(cfg, f"SUPIR_CKPT_{sign}")
        loaded[sign] = load_checkpoint(resolve_ckpt_path(ref))
    return loaded["Q"], loaded["F"]
def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
    '''
    PIL.Image -> (Tensor[3, H, W] RGB in [-1, 1], h0, w0)

    Args:
        img: Input PIL image. Non-RGB modes are converted with
            ``img.convert('RGB')`` (any alpha channel is dropped).
        upsacle: Scale factor applied to both sides (the misspelt name is kept
            for backward compatibility with keyword callers).
        min_size: The working resolution is enlarged (aspect ratio kept) so its
            shorter side is at least this many pixels.
        fix_resize: If given, the working resolution is instead rescaled so its
            shorter side equals this value, and (h0, w0) follow that size.

    Returns:
        x: float32 tensor of the image, bicubic-resized so both sides are
           multiples of 64 and scaled to [-1, 1].
        h0, w0: Target size -- the ``upsacle``-scaled size (or the
           ``fix_resize`` size), i.e. before the min_size enlargement and the
           64-pixel snapping.
    '''
    # A non-RGB image would give a 2-D array (crash in permute) for 'L' or a
    # 4-channel tensor for 'RGBA', contradicting the RGB contract above.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # size
    w, h = img.size
    w *= upsacle
    h *= upsacle
    w0, h0 = round(w), round(h)
    if min(w, h) < min_size:
        factor = min_size / min(w, h)
        w *= factor
        h *= factor
    if fix_resize is not None:
        factor = fix_resize / min(w, h)
        w *= factor
        h *= factor
        w0, h0 = round(w), round(h)
    w = int(np.round(w / 64.0)) * 64
    h = int(np.round(h / 64.0)) * 64
    x = img.resize((w, h), Image.BICUBIC)
    x = np.array(x).round().clip(0, 255).astype(np.uint8)
    x = x / 255 * 2 - 1
    x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
    return x, h0, w0
def Tensor2PIL(x, h0, w0):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image of height h0 and width w0

    The tensor is bicubic-resized to (h0, w0), mapped from [-1, 1] to
    [0, 255], clipped and truncated to uint8.
    '''
    # Work on a detached float32 copy: .numpy() rejects tensors that require
    # grad, NumPy has no bfloat16, and half-precision bicubic interpolation
    # may be unsupported on CPU in some torch versions.
    x = x.detach().float()
    x = interpolate(x.unsqueeze(0), size=(h0, w0), mode='bicubic').squeeze(0)
    x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return Image.fromarray(x)
def HWC3(x):
    """Normalise a uint8 image array to 3-channel (H, W, 3).

    * (H, W) and (H, W, 1) grayscale: the channel is replicated three times.
    * (H, W, 3): returned unchanged (the same object).
    * (H, W, 4): RGBA is alpha-composited over a white background.

    Raises:
        AssertionError: If the dtype is not uint8, the array is not 2-D/3-D,
            or the channel count is not 1, 3 or 4.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    if channels == 4:
        rgb = x[..., :3].astype(np.float32)
        alpha = x[..., 3:].astype(np.float32) / 255.0
        # Blend towards white where the pixel is transparent.
        blended = rgb * alpha + 255.0 * (1.0 - alpha)
        return blended.clip(0, 255).astype(np.uint8)
    return x
def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
    """Rescale an image by ``upscale`` and snap both sides to ``unit_resolution``.

    Args:
        input_image: NumPy image of shape (H, W) or (H, W, C), as accepted by
            ``cv2.resize``.
        upscale: Scale factor applied to both sides.
        min_size: If given, the image is enlarged further (aspect ratio kept)
            so its shorter side is at least ``min_size`` before snapping.
        unit_resolution: Each output side is rounded to the nearest multiple
            of this value, and is never smaller than one unit.

    Returns:
        The resized image as ``np.uint8``.
    """
    in_h, in_w = input_image.shape[:2]
    H = float(in_h) * upscale
    W = float(in_w) * upscale
    # Overall factor relative to the input. Tracked separately because the
    # min_size clamp can turn a nominal downscale (upscale <= 1) into an
    # enlargement, which should then use the enlarging interpolator.
    scale = upscale
    if min_size is not None and min(H, W) < min_size:
        extra = min_size / min(W, H)
        W *= extra
        H *= extra
        scale *= extra
    # Guard against a side rounding to 0, which cv2.resize rejects.
    H = max(unit_resolution, int(np.round(H / unit_resolution)) * unit_resolution)
    W = max(unit_resolution, int(np.round(W / unit_resolution)) * unit_resolution)
    # Lanczos to enlarge; area averaging to shrink (OpenCV notes INTER_AREA
    # behaves like nearest-neighbour when zooming in).
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    img = cv2.resize(input_image, (W, H), interpolation=interpolation)
    img = img.round().clip(0, 255).astype(np.uint8)
    return img
def fix_resize(input_image, size=512, unit_resolution=64):
    """Resize an (H, W, C) image so its shorter side is about ``size``.

    Both sides are scaled by ``size / min(H, W)`` and then rounded to the
    nearest multiple of ``unit_resolution``. Lanczos is used when enlarging,
    area interpolation otherwise; the result is returned as ``np.uint8``.
    """
    height, width, _ = input_image.shape
    factor = size / min(float(height), float(width))

    def snap(length):
        # Nearest multiple of unit_resolution.
        return int(np.round(length / unit_resolution)) * unit_resolution

    new_h = snap(float(height) * factor)
    new_w = snap(float(width) * factor)
    method = cv2.INTER_LANCZOS4 if factor > 1 else cv2.INTER_AREA
    resized = cv2.resize(input_image, (new_w, new_h), interpolation=method)
    return resized.round().clip(0, 255).astype(np.uint8)
def Numpy2Tensor(img):
    '''
    np.array[H, W, C] with values in [0, 255] -> Tensor[C, H, W] float32 in [-1, 1]
    '''
    scaled = np.array(img) / 255 * 2 - 1
    as_tensor = torch.as_tensor(scaled, dtype=torch.float32)
    # HWC -> CHW
    return as_tensor.permute(2, 0, 1)
def Tensor2Numpy(x, h0=None, w0=None):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> np.ndarray[H, W, C] uint8 in [0, 255]

    If both h0 and w0 are given, the tensor is first bicubic-resized to
    (h0, w0). Values are mapped from [-1, 1] to [0, 255], clipped and
    truncated to uint8.
    '''
    # Work on a detached float32 copy: .numpy() rejects tensors that require
    # grad, NumPy has no bfloat16, and half-precision bicubic interpolation
    # may be unsupported on CPU in some torch versions.
    x = x.detach().float()
    if h0 is not None and w0 is not None:
        x = interpolate(x.unsqueeze(0), size=(h0, w0), mode='bicubic').squeeze(0)
    x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return x
def convert_dtype(dtype_str):
    """Map a short precision name to the corresponding torch dtype.

    Args:
        dtype_str: One of ``'fp32'``, ``'fp16'`` or ``'bf16'``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16``.

    Raises:
        NotImplementedError: For any other value (the message names it).
    """
    dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    try:
        return dtypes[dtype_str]
    except (KeyError, TypeError):  # TypeError: unhashable input
        raise NotImplementedError(
            f"Unsupported dtype {dtype_str!r}; expected one of {sorted(dtypes)}"
        ) from None