diff --git a/CKPT_PTH.py b/CKPT_PTH.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee41e72382556902e1c0f82ad110881c4935bcf4
--- /dev/null
+++ b/CKPT_PTH.py
@@ -0,0 +1,2 @@
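+# Text-encoder checkpoints for the SDXL conditioner: CLIP1 is resolved as a Hugging Face
+# model id, CLIP2 points to a local open_clip ViT-bigG checkpoint file.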
+SDXL_CLIP1_PATH = 'openai/clip-vit-large-patch14'
+SDXL_CLIP2_CKPT_PTH = 'laion_CLIP-ViT-bigG-14-laion2B-39B-b160k/open_clip_pytorch_model.bin'
\ No newline at end of file
diff --git a/Examples/Example1.png b/Examples/Example1.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f13a571f0d86c4efd62eabe26370ec18d27d347
Binary files /dev/null and b/Examples/Example1.png differ
diff --git a/Examples/Example2.jpeg b/Examples/Example2.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..1e4d6a6a94864419a8859b4836aa108124967325
Binary files /dev/null and b/Examples/Example2.jpeg differ
diff --git a/Examples/Example3.webp b/Examples/Example3.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a6bbe8f4ef97ca45ef821abde7fa1eb7c2e04473
Binary files /dev/null and b/Examples/Example3.webp differ
diff --git a/README.md b/README.md
index 55b8fccd3a5f495741d055eb55338be5539650d5..5f3b080c09c2d3054e8fd113d71cc835eccdf276 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,21 @@
----
-title: Face Real ESRGAN 2x 4x 8x
-emoji: 😻
-colorFrom: green
-colorTo: gray
-sdk: gradio
-sdk_version: 5.16.0
-python_version: 3.11.11
-app_file: app.py
-pinned: true
-license: apache-2.0
-models:
-- ai-forever/Real-ESRGAN
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
\ No newline at end of file
+---
+title: SUPIR Image Upscaler
+sdk: gradio
+emoji: 📷
+sdk_version: 4.38.1
+app_file: app.py
+license: mit
+colorFrom: blue
+colorTo: pink
+tags:
+ - Upscaling
+ - Restoring
+ - Image-to-Image
+ - Image-2-Image
+ - Img-to-Img
+ - Img-2-Img
+ - language models
+ - LLMs
+short_description: Restore blurred or small images with a prompt
+suggested_hardware: zero-a10g
+---
\ No newline at end of file
diff --git a/SUPIR/__init__.py b/SUPIR/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SUPIR/models/SUPIR_model.py b/SUPIR/models/SUPIR_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d683da7f908f1f2b49a19427d1763af530c3e18f
--- /dev/null
+++ b/SUPIR/models/SUPIR_model.py
@@ -0,0 +1,195 @@
+import torch
+from sgm.models.diffusion import DiffusionEngine
+from sgm.util import instantiate_from_config
+import copy
+from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
+import random
+from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
+from pytorch_lightning import seed_everything
+from torch.nn.functional import interpolate
+from SUPIR.utils.tilevae import VAEHook
+
+class SUPIRModel(DiffusionEngine):
+ def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ control_model = instantiate_from_config(control_stage_config)
+ self.model.load_control_model(control_model)
+ self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
+ self.sampler_config = kwargs['sampler_config']
+
+ assert (ae_dtype in ['fp32', 'fp16', 'bf16']) and (diffusion_dtype in ['fp32', 'fp16', 'bf16'])
+ if ae_dtype == 'fp32':
+ ae_dtype = torch.float32
+ elif ae_dtype == 'fp16':
+            raise RuntimeError('fp16 causes NaN in the AE')
+ elif ae_dtype == 'bf16':
+ ae_dtype = torch.bfloat16
+
+ if diffusion_dtype == 'fp32':
+ diffusion_dtype = torch.float32
+ elif diffusion_dtype == 'fp16':
+ diffusion_dtype = torch.float16
+ elif diffusion_dtype == 'bf16':
+ diffusion_dtype = torch.bfloat16
+
+ self.ae_dtype = ae_dtype
+ self.model.dtype = diffusion_dtype
+
+ self.p_p = p_p
+ self.n_p = n_p
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ with torch.autocast("cuda", dtype=self.ae_dtype):
+ z = self.first_stage_model.encode(x)
+ z = self.scale_factor * z
+ return z
+
+ @torch.no_grad()
+ def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
+ with torch.autocast("cuda", dtype=self.ae_dtype):
+ if is_stage1:
+ h = self.first_stage_model.denoise_encoder_s1(x)
+ else:
+ h = self.first_stage_model.denoise_encoder(x)
+ moments = self.first_stage_model.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ if use_sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ z = self.scale_factor * z
+ return z
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ with torch.autocast("cuda", dtype=self.ae_dtype):
+ out = self.first_stage_model.decode(z)
+ return out.float()
+
+ @torch.no_grad()
+ def batchify_denoise(self, x, is_stage1=False):
+ '''
+ [N, C, H, W], [-1, 1], RGB
+ '''
+ x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
+ return self.decode_first_stage(x)
+
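+    # End-to-end restoration: encode the degraded input with the denoising encoder,
+    # decode it to a cleaned stage-1 image (re-encoded and handed to the sampler as
+    # x_center for restoration guidance), build prompt/control conditioning from the
+    # denoised latent, sample from pure noise, decode, and optionally color-fix the
+    # result against the stage-1 image (wavelet or AdaIN).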
+ @torch.no_grad()
+ def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
+ num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
+ cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
+ '''
+        x: [N, C, H, W], RGB in [-1, 1]; p: list of N prompts
+ '''
+ assert len(x) == len(p)
+ assert color_fix_type in ['Wavelet', 'AdaIn', 'None']
+
+ N = len(x)
+ if num_samples > 1:
+ assert N == 1
+ N = num_samples
+ x = x.repeat(N, 1, 1, 1)
+ p = p * N
+
+ if p_p == 'default':
+ p_p = self.p_p
+ if n_p == 'default':
+ n_p = self.n_p
+
+ self.sampler_config.params.num_steps = num_steps
+ if use_linear_CFG:
+ self.sampler_config.params.guider_config.params.scale_min = cfg_scale
+ self.sampler_config.params.guider_config.params.scale = cfg_scale_start
+ else:
+ self.sampler_config.params.guider_config.params.scale_min = cfg_scale
+ self.sampler_config.params.guider_config.params.scale = cfg_scale
+ self.sampler_config.params.restore_cfg = restoration_scale
+ self.sampler_config.params.s_churn = s_churn
+ self.sampler_config.params.s_noise = s_noise
+ self.sampler = instantiate_from_config(self.sampler_config)
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ _z = self.encode_first_stage_with_denoise(x, use_sample=False)
+ x_stage1 = self.decode_first_stage(_z)
+ z_stage1 = self.encode_first_stage(x_stage1)
+
+ c, uc = self.prepare_condition(_z, p, p_p, n_p, N)
+
+ denoiser = lambda input, sigma, c, control_scale: self.denoiser(
+ self.model, input, sigma, c, control_scale, **kwargs
+ )
+
+ noised_z = torch.randn_like(_z).to(_z.device)
+
+ _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
+ use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
+ samples = self.decode_first_stage(_samples)
+ if color_fix_type == 'Wavelet':
+ samples = wavelet_reconstruction(samples, x_stage1)
+ elif color_fix_type == 'AdaIn':
+ samples = adaptive_instance_normalization(samples, x_stage1)
+ return samples
+
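+    # Tiled VAE: wrap the encoder, denoise encoder and decoder forwards with VAEHook so
+    # large images are processed tile-by-tile, trading some speed for lower peak VRAM.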
+ def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
+ self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
+ self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
+ self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
+ self.first_stage_model.denoise_encoder.forward = VAEHook(
+ self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
+ fast_encoder=False, color_fix=False, to_gpu=True)
+ self.first_stage_model.encoder.forward = VAEHook(
+ self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
+ fast_encoder=False, color_fix=False, to_gpu=True)
+ self.first_stage_model.decoder.forward = VAEHook(
+ self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
+ fast_encoder=False, color_fix=False, to_gpu=True)
+
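+    # Build the SDXL conditioning batch: size/crop/aesthetic embeddings plus the control
+    # latent, with the positive prompt appended to each image prompt and the negative
+    # prompt filling the unconditional branch; a list inside p[0] is treated as per-tile
+    # local prompts (batch size 1 only).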
+ def prepare_condition(self, _z, p, p_p, n_p, N):
+ batch = {}
+ batch['original_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
+ batch['crop_coords_top_left'] = torch.tensor([0, 0]).repeat(N, 1).to(_z.device)
+ batch['target_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
+ batch['aesthetic_score'] = torch.tensor([9.0]).repeat(N, 1).to(_z.device)
+ batch['control'] = _z
+
+ batch_uc = copy.deepcopy(batch)
+ batch_uc['txt'] = [n_p for _ in p]
+
+ if not isinstance(p[0], list):
+ batch['txt'] = [''.join([_p, p_p]) for _p in p]
+ with torch.cuda.amp.autocast(dtype=self.ae_dtype):
+ c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
+ else:
+            assert len(p) == 1, 'Only batch size 1 is supported for tiled (local) prompt conditioning.'
+ p_tiles = p[0]
+ c = []
+ for i, p_tile in enumerate(p_tiles):
+ batch['txt'] = [''.join([p_tile, p_p])]
+ with torch.cuda.amp.autocast(dtype=self.ae_dtype):
+ if i == 0:
+ _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
+ else:
+ _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
+ c.append(_c)
+ return c, uc
+
+
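+# Minimal offline smoke test; the checkpoint paths below are the author's local paths
+# and need to be adjusted for other environments.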
+if __name__ == '__main__':
+ from SUPIR.util import create_model, load_state_dict
+
+ model = create_model('../../options/dev/SUPIR_paper_version.yaml')
+
+ SDXL_CKPT = '/opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors'
+ SUPIR_CKPT = '/opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-paper.ckpt'
+ model.load_state_dict(load_state_dict(SDXL_CKPT), strict=False)
+ model.load_state_dict(load_state_dict(SUPIR_CKPT), strict=False)
+ model = model.cuda()
+
+ x = torch.randn(1, 3, 512, 512).cuda()
+ p = ['a professional, detailed, high-quality photo']
+ samples = model.batchify_sample(x, p, num_steps=50, restoration_scale=4.0, s_churn=0, cfg_scale=4.0, seed=-1, num_samples=1)
diff --git a/SUPIR/models/__init__.py b/SUPIR/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SUPIR/modules/SUPIR_v0.py b/SUPIR/modules/SUPIR_v0.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b74d2b0786b9febd6efe82df08cbf3f0a4b866
--- /dev/null
+++ b/SUPIR/modules/SUPIR_v0.py
@@ -0,0 +1,718 @@
+# from einops._torch_specific import allow_ops_in_compiled_graph
+# allow_ops_in_compiled_graph()
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sgm.modules.diffusionmodules.util import (
+ avg_pool_nd,
+ checkpoint,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+
+from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
+ TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
+from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
+from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
+import re
+import torch
+from functools import partial
+
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
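+# Zero-initialized 1x1 convolution that adds control features to a feature map, so the
+# control path has no effect at the start of training (optionally concatenating h_ori).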
+class ZeroConv(nn.Module):
+ def __init__(self, label_nc, norm_nc, mask=False):
+ super().__init__()
+ self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
+ self.mask = mask
+
+ def forward(self, c, h, h_ori=None):
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ if not self.mask:
+ h = h + self.zero_conv(c)
+ else:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ if h_ori is not None:
+ h = th.cat([h_ori, h], dim=1)
+ return h
+
+
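+# Zero-initialized SFT-style modulation: predicts a per-pixel scale (gamma) and shift
+# (beta) from the control features, applies them after group normalization, and blends
+# the modulated and unmodulated paths according to control_scale.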
+class ZeroSFT(nn.Module):
+ def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
+ super().__init__()
+
+ # param_free_norm_type = str(parsed.group(1))
+ ks = 3
+ pw = ks // 2
+
+ self.norm = norm
+ if self.norm:
+ self.param_free_norm = normalization(norm_nc + concat_channels)
+ else:
+ self.param_free_norm = nn.Identity()
+
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
+ nn.SiLU()
+ )
+ self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+ self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+ # self.zero_mul = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
+ # self.zero_add = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
+
+ self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
+ self.pre_concat = bool(concat_channels != 0)
+ self.mask = mask
+
+ def forward(self, c, h, h_ori=None, control_scale=1):
+ assert self.mask is False
+ if h_ori is not None and self.pre_concat:
+ h_raw = th.cat([h_ori, h], dim=1)
+ else:
+ h_raw = h
+
+ if self.mask:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ else:
+ h = h + self.zero_conv(c)
+ if h_ori is not None and self.pre_concat:
+ h = th.cat([h_ori, h], dim=1)
+ actv = self.mlp_shared(c)
+ gamma = self.zero_mul(actv)
+ beta = self.zero_add(actv)
+ if self.mask:
+ gamma = gamma * torch.zeros_like(gamma)
+ beta = beta * torch.zeros_like(beta)
+ h = self.param_free_norm(h) * (gamma + 1) + beta
+ if h_ori is not None and not self.pre_concat:
+ h = th.cat([h_ori, h], dim=1)
+ return h * control_scale + h_raw * (1 - control_scale)
+
+
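+# Control injection via cross-attention: the UNet features attend to the normalized
+# control features and the attention output is added back, scaled by control_scale.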
+class ZeroCrossAttn(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+
+ def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn = attn_cls(query_dim=query_dim, context_dim=context_dim, heads=query_dim//64, dim_head=64)
+ self.norm1 = normalization(query_dim)
+ self.norm2 = normalization(context_dim)
+
+ self.mask = mask
+
+ # if zero_out:
+ # # for p in self.attn.to_out.parameters():
+ # # p.detach().zero_()
+ # self.attn.to_out = zero_module(self.attn.to_out)
+
+ def forward(self, context, x, control_scale=1):
+ assert self.mask is False
+ x_in = x
+ x = self.norm1(x)
+ context = self.norm2(context)
+ b, c, h, w = x.shape
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
+ x = self.attn(x, context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if self.mask:
+ x = x * torch.zeros_like(x)
+ x = x_in + x * control_scale
+
+ return x
+
+
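+# Control branch of SUPIR: a UNet encoder plus middle block that injects a hint derived
+# from the degraded latent (optionally upscaled) into the noisy latent `xt` and returns
+# the resulting multi-scale features, which LightGLVUNet fuses into its decoder.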
+class GLVControl(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ input_upscale=1,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (
+ use_checkpoint or use_fairscale_checkpoint
+ )
+
+        # Fairscale activation checkpointing is disabled here; enabling it would also
+        # require importing `checkpoint_wrapper` from fairscale.
+        self.use_fairscale_checkpoint = False
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+
+ self.input_upscale = input_upscale
+ self.input_hint_block = TimestepEmbedSequential(
+ zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ # x = x.to(torch.float32)
+ # timesteps = timesteps.to(torch.float32)
+ # xt = xt.to(torch.float32)
+ # context = context.to(torch.float32)
+ # y = y.to(torch.float32)
+ # print(x.dtype)
+ xt, context, y = xt.to(x.dtype), context.to(x.dtype), y.to(x.dtype)
+
+ if self.input_upscale != 1:
+ x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ # import pdb
+ # pdb.set_trace()
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == xt.shape[0]
+ emb = emb + self.label_emb(y)
+
+ guided_hint = self.input_hint_block(x, emb, context)
+
+ # h = x.type(self.dtype)
+ h = xt
+ for module in self.input_blocks:
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ hs.append(h)
+ # print(module)
+ # print(h.shape)
+ h = self.middle_block(h, emb, context)
+ hs.append(h)
+ return hs
+
+
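+# SDXL UNet variant that consumes GLVControl features: ZeroSFT / ZeroCrossAttn
+# "project_modules" fuse one control feature into the middle block and each decoder
+# stage, with channel layouts matching the SDXL base ('XL-base') and refiner
+# ('XL-refine') UNets.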
+class LightGLVUNet(UNetModel):
+ def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
+ *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if mode == 'XL-base':
+ cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
+ project_channels = [160] * 4 + [320] * 3 + [640] * 3
+ concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
+ cross_attn_insert_idx = [6, 3]
+ self.progressive_mask_nums = [0, 3, 7, 11]
+ elif mode == 'XL-refine':
+ cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
+ project_channels = [192] * 4 + [384] * 3 + [768] * 6
+ concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
+ cross_attn_insert_idx = [9, 6, 3]
+ self.progressive_mask_nums = [0, 3, 6, 10, 14]
+ else:
+ raise NotImplementedError
+
+ project_channels = [int(c * project_channel_scale) for c in project_channels]
+
+ self.project_modules = nn.ModuleList()
+ for i in range(len(cond_output_channels)):
+ # if i == len(cond_output_channels) - 1:
+ # _project_type = 'ZeroCrossAttn'
+ # else:
+ # _project_type = project_type
+ _project_type = project_type
+ if _project_type == 'ZeroSFT':
+ self.project_modules.append(ZeroSFT(project_channels[i], cond_output_channels[i],
+ concat_channels=concat_channels[i]))
+ elif _project_type == 'ZeroCrossAttn':
+ self.project_modules.append(ZeroCrossAttn(cond_output_channels[i], project_channels[i]))
+ else:
+ raise NotImplementedError
+
+ for i in cross_attn_insert_idx:
+ self.project_modules.insert(i, ZeroCrossAttn(cond_output_channels[i], concat_channels[i]))
+ # print(self.project_modules[i])
+
+ def step_progressive_mask(self):
+ if len(self.progressive_mask_nums) > 0:
+ mask_num = self.progressive_mask_nums.pop()
+ for i in range(len(self.project_modules)):
+ if i < mask_num:
+ self.project_modules[i].mask = True
+ else:
+ self.project_modules[i].mask = False
+ return
+ # print(f'step_progressive_mask, current masked layers: {mask_num}')
+ else:
+ return
+ # print('step_progressive_mask, no more masked layers')
+ # for i in range(len(self.project_modules)):
+ # print(self.project_modules[i].mask)
+
+
+ def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+
+ _dtype = control[0].dtype
+ x, context, y = x.to(_dtype), context.to(_dtype), y.to(_dtype)
+
+ with torch.no_grad():
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+
+ adapter_idx = len(self.project_modules) - 1
+ control_idx = len(control) - 1
+ h = self.middle_block(h, emb, context)
+ h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
+ adapter_idx -= 1
+ control_idx -= 1
+
+ for i, module in enumerate(self.output_blocks):
+ _h = hs.pop()
+ h = self.project_modules[adapter_idx](control[control_idx], _h, h, control_scale=control_scale)
+ adapter_idx -= 1
+ # h = th.cat([h, _h], dim=1)
+ if len(module) == 3:
+ assert isinstance(module[2], Upsample)
+ for layer in module[:2]:
+ if isinstance(layer, TimestepBlock):
+ h = layer(h, emb)
+ elif isinstance(layer, SpatialTransformer):
+ h = layer(h, context)
+ else:
+ h = layer(h)
+ # print('cross_attn_here')
+ h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
+ adapter_idx -= 1
+ h = module[2](h)
+ else:
+ h = module(h, emb, context)
+ control_idx -= 1
+ # print(module)
+ # print(h.shape)
+
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+            assert False, "predict_codebook_ids is not supported anymore."
+ else:
+ return self.out(h)
+
+if __name__ == '__main__':
+ from omegaconf import OmegaConf
+
+ # refiner
+ # opt = OmegaConf.load('../../options/train/debug_p2_xl.yaml')
+ #
+ # model = instantiate_from_config(opt.model.params.control_stage_config)
+ # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
+ # hint = [h.cuda() for h in hint]
+ # print(sum(map(lambda hint: hint.numel(), model.parameters())))
+ #
+ # unet = instantiate_from_config(opt.model.params.network_config)
+ # unet = unet.cuda()
+ #
+ # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
+ # torch.randn([1, 2560]).cuda(), hint)
+ # print(sum(map(lambda _output: _output.numel(), unet.parameters())))
+
+ # base
+ with torch.no_grad():
+ opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')
+
+ model = instantiate_from_config(opt.model.params.control_stage_config)
+ model = model.cuda()
+
+ hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 2048]).cuda(),
+ torch.randn([1, 2816]).cuda())
+
+ for h in hint:
+ print(h.shape)
+ #
+ unet = instantiate_from_config(opt.model.params.network_config)
+ unet = unet.cuda()
+ _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 2048]).cuda(),
+ torch.randn([1, 2816]).cuda(), hint)
+
+
+ # model = instantiate_from_config(opt.model.params.control_stage_config)
+ # model = model.cuda()
+ # # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
+ # hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 1280]).cuda(),
+ # torch.randn([1, 2560]).cuda())
+ # # hint = [h.cuda() for h in hint]
+ #
+ # for h in hint:
+ # print(h.shape)
+ #
+ # unet = instantiate_from_config(opt.model.params.network_config)
+ # unet = unet.cuda()
+ # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
+ # torch.randn([1, 2560]).cuda(), hint)
diff --git a/SUPIR/modules/__init__.py b/SUPIR/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06574dd3ff7d8e44e88c3e57d3b11a9ccdde194e
--- /dev/null
+++ b/SUPIR/modules/__init__.py
@@ -0,0 +1,11 @@
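+# Per-stage channel layouts of the SDXL base and refiner UNets; they mirror the
+# hard-coded 'XL-base' / 'XL-refine' settings in SUPIR_v0.LightGLVUNet.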
+SDXL_BASE_CHANNEL_DICT = {
+ 'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
+ 'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
+ 'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
+}
+
+SDXL_REFINE_CHANNEL_DICT = {
+ 'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
+ 'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
+ 'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
+}
\ No newline at end of file
diff --git a/SUPIR/util.py b/SUPIR/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb5ae31ed938ec41177902dcc2e8dd14ef2d4b0
--- /dev/null
+++ b/SUPIR/util.py
@@ -0,0 +1,179 @@
+import os
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from torch.nn.functional import interpolate
+from omegaconf import OmegaConf
+from sgm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+ return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+ state_dict = get_state_dict(state_dict)
+ print(f'Loaded state_dict from [{ckpt_path}]')
+ return state_dict
+
+
+def create_model(config_path):
+ config = OmegaConf.load(config_path)
+ model = instantiate_from_config(config.model).cpu()
+ print(f'Loaded model config from [{config_path}]')
+ return model
+
+
+def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
+ config = OmegaConf.load(config_path)
+ model = instantiate_from_config(config.model).cpu()
+ print(f'Loaded model config from [{config_path}]')
+ if config.SDXL_CKPT is not None:
+ model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
+ if config.SUPIR_CKPT is not None:
+ model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
+ if SUPIR_sign is not None:
+ assert SUPIR_sign in ['F', 'Q']
+ if SUPIR_sign == 'F':
+ model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
+ elif SUPIR_sign == 'Q':
+ model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
+ if load_default_setting:
+ default_setting = config.default_setting
+ return model, default_setting
+ return model
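+
+# Illustrative usage (the config path is an example; checkpoint locations come from the
+# YAML's SDXL_CKPT / SUPIR_CKPT / SUPIR_CKPT_F / SUPIR_CKPT_Q keys):
+#   model, default_setting = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q',
+#                                               load_default_setting=True)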
+
+def load_QF_ckpt(config_path):
+ config = OmegaConf.load(config_path)
+ ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu')
+ ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu')
+ return ckpt_Q, ckpt_F
+
+
+def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
+ '''
+    PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]; also returns the target (h0, w0) before snapping to multiples of 64
+ '''
+ # size
+ w, h = img.size
+ w *= upsacle
+ h *= upsacle
+ w0, h0 = round(w), round(h)
+ if min(w, h) < min_size:
+ _upsacle = min_size / min(w, h)
+ w *= _upsacle
+ h *= _upsacle
+ if fix_resize is not None:
+ _upsacle = fix_resize / min(w, h)
+ w *= _upsacle
+ h *= _upsacle
+ w0, h0 = round(w), round(h)
+ w = int(np.round(w / 64.0)) * 64
+ h = int(np.round(h / 64.0)) * 64
+ x = img.resize((w, h), Image.BICUBIC)
+ x = np.array(x).round().clip(0, 255).astype(np.uint8)
+ x = x / 255 * 2 - 1
+ x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
+ return x, h0, w0
+
+
+def Tensor2PIL(x, h0, w0):
+ '''
+ Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
+ '''
+ x = x.unsqueeze(0)
+ x = interpolate(x, size=(h0, w0), mode='bicubic')
+ x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
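+# Scale an HWC uint8 image by `upscale`, optionally enforce a minimum short side, snap
+# both dimensions to multiples of `unit_resolution`, and resize with Lanczos when
+# enlarging or area interpolation when shrinking.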
+def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ H *= upscale
+ W *= upscale
+ if min_size is not None:
+ if min(H, W) < min_size:
+ _upsacle = min_size / min(W, H)
+ W *= _upsacle
+ H *= _upsacle
+ H = int(np.round(H / unit_resolution)) * unit_resolution
+ W = int(np.round(W / unit_resolution)) * unit_resolution
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
+ img = img.round().clip(0, 255).astype(np.uint8)
+ return img
+
+
+def fix_resize(input_image, size=512, unit_resolution=64):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ upscale = size / min(H, W)
+ H *= upscale
+ W *= upscale
+ H = int(np.round(H / unit_resolution)) * unit_resolution
+ W = int(np.round(W / unit_resolution)) * unit_resolution
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
+ img = img.round().clip(0, 255).astype(np.uint8)
+ return img
+
+
+
+def Numpy2Tensor(img):
+ '''
+    np.ndarray[H, W, C], [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
+ '''
+ # size
+ img = np.array(img) / 255 * 2 - 1
+ img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
+ return img
+
+
+def Tensor2Numpy(x, h0=None, w0=None):
+ '''
+    Tensor[C, H, W], RGB, [-1, 1] -> np.ndarray[H, W, C], RGB, [0, 255], uint8
+ '''
+ if h0 is not None and w0 is not None:
+ x = x.unsqueeze(0)
+ x = interpolate(x, size=(h0, w0), mode='bicubic')
+ x = x.squeeze(0)
+ x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+ return x
+
+
+def convert_dtype(dtype_str):
+ if dtype_str == 'fp32':
+ return torch.float32
+ elif dtype_str == 'fp16':
+ return torch.float16
+ elif dtype_str == 'bf16':
+ return torch.bfloat16
+ else:
+ raise NotImplementedError
diff --git a/SUPIR/utils/__init__.py b/SUPIR/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SUPIR/utils/colorfix.py b/SUPIR/utils/colorfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2dedd7c6c6aead6b96db84b034a3bd16012d3b6
--- /dev/null
+++ b/SUPIR/utils/colorfix.py
@@ -0,0 +1,120 @@
+'''
+# --------------------------------------------------------------------------------
+# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
+# --------------------------------------------------------------------------------
+'''
+
+import torch
+from PIL import Image
+from torch import Tensor
+from torch.nn import functional as F
+
+from torchvision.transforms import ToTensor, ToPILImage
+
+def adain_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply adaptive instance normalization
+ result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def wavelet_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply wavelet reconstruction
+ result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def calc_mean_std(feat: Tensor, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+ feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
+ """Adaptive instance normalization.
+    Adjust the reference features so that their color and illumination match
+    those of the degraded features.
+ Args:
+ content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def wavelet_blur(image: Tensor, radius: int):
+ """
+ Apply wavelet blur to the input tensor.
+ """
+ # input shape: (1, 3, H, W)
+ # convolution kernel
+ kernel_vals = [
+ [0.0625, 0.125, 0.0625],
+ [0.125, 0.25, 0.125],
+ [0.0625, 0.125, 0.0625],
+ ]
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+ # add channel dimensions to the kernel to make it a 4D tensor
+ kernel = kernel[None, None]
+ # repeat the kernel across all input channels
+ kernel = kernel.repeat(3, 1, 1, 1)
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+ # apply convolution
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
+ return output
+
+def wavelet_decomposition(image: Tensor, levels=5):
+ """
+ Apply wavelet decomposition to the input tensor.
+    Returns the accumulated high-frequency component and the final low-frequency component.
+ """
+ high_freq = torch.zeros_like(image)
+ for i in range(levels):
+ radius = 2 ** i
+ low_freq = wavelet_blur(image, radius)
+ high_freq += (image - low_freq)
+ image = low_freq
+
+ return high_freq, low_freq
+
+def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
+ """
+    Recombine the content's high-frequency component with the style's low-frequency
+    component, so the content keeps its detail while taking on the style's colors.
+ """
+ # calculate the wavelet decomposition of the content feature
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+ del content_low_freq
+ # calculate the wavelet decomposition of the style feature
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+ del style_high_freq
+ # reconstruct the content feature with the style's high frequency
+ return content_high_freq + style_low_freq
+
diff --git a/SUPIR/utils/devices.py b/SUPIR/utils/devices.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc64037429a466bf8ea24432870d70e58631fe89
--- /dev/null
+++ b/SUPIR/utils/devices.py
@@ -0,0 +1,138 @@
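+# Trimmed-down device helpers adapted from the AUTOMATIC1111 stable-diffusion-webui
+# (modules/devices.py); the macOS branch assumes that webui's `modules.mac_specific`
+# is importable.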
+import sys
+import contextlib
+from functools import lru_cache
+
+import torch
+#from modules import errors
+
+if sys.platform == "darwin":
+ from modules import mac_specific
+
+
+def has_mps() -> bool:
+ if sys.platform != "darwin":
+ return False
+ else:
+ return mac_specific.has_mps
+
+
+def get_cuda_device_string():
+ return "cuda"
+
+
+def get_optimal_device_name():
+ if torch.cuda.is_available():
+ return get_cuda_device_string()
+
+ if has_mps():
+ return "mps"
+
+ return "cpu"
+
+
+def get_optimal_device():
+ return torch.device(get_optimal_device_name())
+
+
+def get_device_for(task):
+ return get_optimal_device()
+
+
+def torch_gc():
+
+ if torch.cuda.is_available():
+ with torch.cuda.device(get_cuda_device_string()):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
+
+
+def enable_tf32():
+ if torch.cuda.is_available():
+
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
+ torch.backends.cudnn.benchmark = True
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+
+enable_tf32()
+#errors.run(enable_tf32, "Enabling TF32")
+
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
+dtype = torch.float16
+dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
+
+
+def cond_cast_unet(input):
+ return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+ return input.float() if unet_needs_upcast else input
+
+
+def randn(seed, shape):
+ torch.manual_seed(seed)
+ return torch.randn(shape, device=device)
+
+
+def randn_without_seed(shape):
+ return torch.randn(shape, device=device)
+
+
+def autocast(disable=False):
+ if disable:
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")
+
+
+def without_autocast(disable=False):
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ message += " Use --disable-nan-check commandline argument to disable this check."
+
+ raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+ """
+    Just do any calculation with PyTorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVIDIA.
+ """
+
+ x = torch.zeros((1, 1)).to(device, dtype)
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
+ linear(x)
+
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+ conv2d(x)
\ No newline at end of file
diff --git a/SUPIR/utils/face_restoration_helper.py b/SUPIR/utils/face_restoration_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5499e94621ef526f299c0c236a76369707203a2
--- /dev/null
+++ b/SUPIR/utils/face_restoration_helper.py
@@ -0,0 +1,514 @@
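+# Face detection, alignment and paste-back helper built on facexlib's detection and
+# parsing models (RetinaFace / ParseNet), with an optional dlib detection path.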
+import cv2
+import numpy as np
+import os
+import torch
+from torchvision.transforms.functional import normalize
+
+from facexlib.detection import init_detection_model
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.misc import img2tensor, imwrite
+
+from .file import load_file_from_url
+
+
+def get_largest_face(det_faces, h, w):
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop_ratio only supports values >= 1'
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+ self.det_model = det_model
+
+ if self.det_model == 'dlib':
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861], [437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+ elif self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ # self.device = get_device()
+ else:
+ self.device = device
+
+ # init face detection model
+ self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+ # self.is_gray = is_gray(img, threshold=10)
+ # if self.is_gray:
+ # print('Grayscale input: True')
+
+ if min(self.input_img.shape[:2]) < 512:
+ f = 512.0 / min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def init_dlib(self, detection_path, landmark5_path):
+ """Initialize the dlib detectors and predictors."""
+ try:
+ import dlib
+ except ImportError:
+            print('Please install dlib by running: ' 'conda install -c conda-forge dlib')
+ detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
+ landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+ return face_detector, shape_predictor_5
+
+ def get_face_landmarks_5_dlib(self,
+ only_keep_largest=False,
+ scale=1):
+ det_faces = self.face_detector(self.input_img, scale)
+
+ if len(det_faces) == 0:
+ print('No face detected. Try to increase upsample_num_times.')
+ return 0
+ else:
+ if only_keep_largest:
+ print('Detect several faces and only keep the largest.')
+ face_areas = []
+ for i in range(len(det_faces)):
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ self.det_faces = [det_faces[largest_idx]]
+ else:
+ self.det_faces = det_faces
+
+ if len(self.det_faces) == 0:
+ return 0
+
+ for face in self.det_faces:
+ shape = self.shape_predictor_5(self.input_img, face.rect)
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
+ self.all_landmarks_5.append(landmark)
+
+ return len(self.all_landmarks_5)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if self.det_model == 'dlib':
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
+
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_detector.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
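+ # Typical call order for this helper (an illustrative sketch of how callers
+ # usually drive it): detect landmarks (the method above) -> align_warp_face()
+ # -> restore each cropped face -> add_restored_face() -> get_inverse_affine()
+ # -> paste_faces_to_input_image().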
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+ def add_restored_face(self, restored_face, input_face=None):
+ # if self.is_gray:
+ # restored_face = bgr2gray(restored_face) # convert img into grayscale
+ # if input_face is not None:
+ # restored_face = adain_npy(restored_face, input_face) # transfer the color
+ self.restored_faces.append(restored_face)
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), 'lengths of restored_faces and inverse_affine_matrices are different'
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400 / np.sqrt(total_face_area))
+ mask_border[border:h - border, border:w - border, :] = 0
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area ** 0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+ inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
+ alpha = upsample_img[:, :, 3:]
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+ else:
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:, :, 0] = 0
+ img_color[:, :, 1] = 255
+ img_color[:, :, 2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
\ No newline at end of file
diff --git a/SUPIR/utils/file.py b/SUPIR/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1b86304bf8038351c64aa6601b3778c43e91e7
--- /dev/null
+++ b/SUPIR/utils/file.py
@@ -0,0 +1,79 @@
+import os
+from typing import List, Tuple
+
+from urllib.parse import urlparse
+from torch.hub import download_url_to_file, get_dir
+
+
+def load_file_list(file_list_path: str) -> List[str]:
+ files = []
+ # each line in file list contains a path of an image
+ with open(file_list_path, "r") as fin:
+ for line in fin:
+ path = line.strip()
+ if path:
+ files.append(path)
+ return files
+
+
+def list_image_files(
+ img_dir: str,
+ exts: Tuple[str, ...] = (".jpg", ".png", ".jpeg"),
+ follow_links: bool=False,
+ log_progress: bool=False,
+ log_every_n_files: int=10000,
+ max_size: int=-1
+) -> List[str]:
+ files = []
+ for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
+ early_stop = False
+ for file_name in file_names:
+ if os.path.splitext(file_name)[1].lower() in exts:
+ if max_size >= 0 and len(files) >= max_size:
+ early_stop = True
+ break
+ files.append(os.path.join(dir_path, file_name))
+ if log_progress and len(files) % log_every_n_files == 0:
+ print(f"find {len(files)} images in {img_dir}")
+ if early_stop:
+ break
+ return files
+
+
+def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
+ parent_path, file_name = os.path.split(file_path)
+ stem, ext = os.path.splitext(file_name)
+ return parent_path, stem, ext
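+ # e.g. get_file_name_parts("/data/images/cat.png") -> ("/data/images", "cat", ".png")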
+
+
+# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
diff --git a/SUPIR/utils/tilevae.py b/SUPIR/utils/tilevae.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1b6668c2c3b5934acdf836e57940e809b190dfb
--- /dev/null
+++ b/SUPIR/utils/tilevae.py
@@ -0,0 +1,971 @@
+# ------------------------------------------------------------------------
+#
+# Ultimate VAE Tile Optimization
+#
+# Introducing a revolutionary new optimization designed to make
+# the VAE work with giant images on limited VRAM!
+# Say goodbye to the frustration of OOM and hello to seamless output!
+#
+# ------------------------------------------------------------------------
+#
+# This script is a wild hack that splits the image into tiles,
+# encodes each tile separately, and merges the result back together.
+#
+# Advantages:
+# - The VAE can now work with giant images on limited VRAM
+# (~10 GB for 8K images!)
+# - The merged output is completely seamless without any post-processing.
+#
+# Drawbacks:
+# - Giant RAM needed. To store the intermediate results for a 4096x4096
+# image, you need a 32 GB RAM machine (it consumes ~20 GB); for an 8192x8192
+# image you need a 128 GB RAM machine (it consumes ~100 GB).
+# - NaNs always appear for 8K images when you use the fp16 (half) VAE.
+# You must use --no-half-vae to disable the half VAE for such giant images.
+# - Slow speed. With the default tile size, it takes around 50/200 seconds
+# to encode/decode a 4096x4096 image, and 200/900 seconds to encode/decode
+# an 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
+# - The gradient calculation is not compatible with this hack. It
+# will break any backward() or torch.autograd.grad() that passes VAE.
+# (But you can still use the VAE to generate training data.)
+#
+# How it works:
+# 1) The image is split into tiles.
+# - To ensure perfect results, each tile is padded with 32 pixels
+# on each side.
+# - Then the conv2d/silu/upsample/downsample can produce identical
+# results to the original image without splitting.
+# 2) The original forward is decomposed into a task queue and a task worker.
+# - The task queue is a list of functions that will be executed in order.
+# - The task worker is a loop that executes the tasks in the queue.
+# 3) The task queue is executed for each tile.
+# - Current tile is sent to GPU.
+# - Local operations are directly executed.
+# - Group norm calculation is temporarily suspended until the mean
+# and var of all tiles are calculated.
+# - The residual is pre-calculated, stored, and added back later.
+# - When we need to go to the next tile, the current tile is sent to the CPU.
+# 4) After all tiles are processed, they are merged on the CPU and returned.
+#
+# Enjoy!
+#
+# @author: LI YI @ Nanyang Technological University - Singapore
+# @date: 2023-03-02
+# @license: MIT License
+#
+# Please give me a star if you like this project!
+#
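+# Illustrative wiring (a sketch, not part of the original code): VAEHook
+# replaces an encoder/decoder `forward` and falls back to `original_forward`
+# for small inputs, so callers are expected to save the original first.
+# The `vae` object below is an assumed placeholder:
+#
+#   encoder = vae.encoder
+#   encoder.original_forward = encoder.forward
+#   encoder.forward = VAEHook(encoder, tile_size=1024, is_decoder=False,
+#                             fast_decoder=False, fast_encoder=False,
+#                             color_fix=False, to_gpu=False)
+#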
+# -------------------------------------------------------------------------
+
+import gc
+from time import time
+import math
+from tqdm import tqdm
+
+import torch
+import torch.version
+import torch.nn.functional as F
+from einops import rearrange
+from diffusers.utils.import_utils import is_xformers_available
+
+import SUPIR.utils.devices as devices
+
+try:
+ import xformers
+ import xformers.ops
+except ImportError:
+ pass
+
+sd_flag = True
+
+def get_recommend_encoder_tile_size():
+ if torch.cuda.is_available():
+ total_memory = torch.cuda.get_device_properties(
+ devices.device).total_memory // 2**20
+ if total_memory > 16*1000:
+ ENCODER_TILE_SIZE = 3072
+ elif total_memory > 12*1000:
+ ENCODER_TILE_SIZE = 2048
+ elif total_memory > 8*1000:
+ ENCODER_TILE_SIZE = 1536
+ else:
+ ENCODER_TILE_SIZE = 960
+ else:
+ ENCODER_TILE_SIZE = 512
+ return ENCODER_TILE_SIZE
+
+
+def get_recommend_decoder_tile_size():
+ if torch.cuda.is_available():
+ total_memory = torch.cuda.get_device_properties(
+ devices.device).total_memory // 2**20
+ if total_memory > 30*1000:
+ DECODER_TILE_SIZE = 256
+ elif total_memory > 16*1000:
+ DECODER_TILE_SIZE = 192
+ elif total_memory > 12*1000:
+ DECODER_TILE_SIZE = 128
+ elif total_memory > 8*1000:
+ DECODER_TILE_SIZE = 96
+ else:
+ DECODER_TILE_SIZE = 64
+ else:
+ DECODER_TILE_SIZE = 64
+ return DECODER_TILE_SIZE
+
+
+if 'global const':
+ DEFAULT_ENABLED = False
+ DEFAULT_MOVE_TO_GPU = False
+ DEFAULT_FAST_ENCODER = True
+ DEFAULT_FAST_DECODER = True
+ DEFAULT_COLOR_FIX = 0
+ DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
+ DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
+
+
+# inplace version of silu
+def inplace_nonlinearity(x):
+ # Test: fix for Nans
+ return F.silu(x, inplace=True)
+
+# extracted from ldm.modules.diffusionmodules.model
+
+# from diffusers lib
+def attn_forward_new(self, h_):
+ batch_size, channel, height, width = h_.shape
+ hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
+
+ attention_mask = None
+ encoder_hidden_states = None
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ query = self.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif self.norm_cross:
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ query = self.head_to_batch_dim(query)
+ key = self.head_to_batch_dim(key)
+ value = self.head_to_batch_dim(value)
+
+ attention_probs = self.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = self.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states
+
+def attn_forward_new_pt2_0(self, hidden_states,):
+ scale = 1
+ attention_mask = None
+ encoder_hidden_states = None
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states, scale=scale)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif self.norm_cross:
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = self.to_k(encoder_hidden_states, scale=scale)
+ value = self.to_v(encoder_hidden_states, scale=scale)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.heads
+
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states
+
+def attn_forward_new_xformers(self, hidden_states):
+ scale = 1
+ attention_op = None
+ attention_mask = None
+ encoder_hidden_states = None
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = self.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states, scale=scale)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif self.norm_cross:
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = self.to_k(encoder_hidden_states, scale=scale)
+ value = self.to_v(encoder_hidden_states, scale=scale)
+
+ query = self.head_to_batch_dim(query).contiguous()
+ key = self.head_to_batch_dim(key).contiguous()
+ value = self.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=attention_op#, scale=scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = self.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states
+
+def attn_forward(self, h_):
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h*w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return h_
+
+
+def xformer_attn_forward(self, h_):
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return out
+
+
+def attn2task(task_queue, net):
+ if False: #isinstance(net, AttnBlock):
+ task_queue.append(('store_res', lambda x: x))
+ task_queue.append(('pre_norm', net.norm))
+ task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
+ task_queue.append(['add_res', None])
+ elif False: #isinstance(net, MemoryEfficientAttnBlock):
+ task_queue.append(('store_res', lambda x: x))
+ task_queue.append(('pre_norm', net.norm))
+ task_queue.append(
+ ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
+ task_queue.append(['add_res', None])
+ else:
+ task_queue.append(('store_res', lambda x: x))
+ task_queue.append(('pre_norm', net.norm))
+ if is_xformers_available():
+ # task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
+ task_queue.append(
+ ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
+ elif hasattr(F, "scaled_dot_product_attention"):
+ task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
+ else:
+ task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
+ task_queue.append(['add_res', None])
+
+def resblock2task(queue, block):
+ """
+ Turn a ResNetBlock into a sequence of tasks and append to the task queue
+
+ @param queue: the target task queue
+ @param block: ResNetBlock
+
+ """
+ if block.in_channels != block.out_channels:
+ if sd_flag:
+ if block.use_conv_shortcut:
+ queue.append(('store_res', block.conv_shortcut))
+ else:
+ queue.append(('store_res', block.nin_shortcut))
+ else:
+ if block.use_in_shortcut:
+ queue.append(('store_res', block.conv_shortcut))
+ else:
+ queue.append(('store_res', block.nin_shortcut))
+
+ else:
+ queue.append(('store_res', lambda x: x))
+ queue.append(('pre_norm', block.norm1))
+ queue.append(('silu', inplace_nonlinearity))
+ queue.append(('conv1', block.conv1))
+ queue.append(('pre_norm', block.norm2))
+ queue.append(('silu', inplace_nonlinearity))
+ queue.append(('conv2', block.conv2))
+ queue.append(['add_res', None])
+
+
+def build_sampling(task_queue, net, is_decoder):
+ """
+ Build the sampling part of a task queue
+ @param task_queue: the target task queue
+ @param net: the network
+ @param is_decoder: currently building decoder or encoder
+ """
+ if is_decoder:
+ if sd_flag:
+ resblock2task(task_queue, net.mid.block_1)
+ attn2task(task_queue, net.mid.attn_1)
+ print(task_queue)
+ resblock2task(task_queue, net.mid.block_2)
+ resolution_iter = reversed(range(net.num_resolutions))
+ block_ids = net.num_res_blocks + 1
+ condition = 0
+ module = net.up
+ func_name = 'upsample'
+ else:
+ resblock2task(task_queue, net.mid_block.resnets[0])
+ attn2task(task_queue, net.mid_block.attentions[0])
+ resblock2task(task_queue, net.mid_block.resnets[1])
+ resolution_iter = (range(len(net.up_blocks))) # net.num_resolutions = 3
+ block_ids = 2 + 1
+ condition = len(net.up_blocks) - 1
+ module = net.up_blocks
+ func_name = 'upsamplers'
+ else:
+ if sd_flag:
+ resolution_iter = range(net.num_resolutions)
+ block_ids = net.num_res_blocks
+ condition = net.num_resolutions - 1
+ module = net.down
+ func_name = 'downsample'
+ else:
+ resolution_iter = range(len(net.down_blocks))
+ block_ids = 2
+ condition = len(net.down_blocks) - 1
+ module = net.down_blocks
+ func_name = 'downsamplers'
+
+ for i_level in resolution_iter:
+ for i_block in range(block_ids):
+ if sd_flag:
+ resblock2task(task_queue, module[i_level].block[i_block])
+ else:
+ resblock2task(task_queue, module[i_level].resnets[i_block])
+ if i_level != condition:
+ if sd_flag:
+ task_queue.append((func_name, getattr(module[i_level], func_name)))
+ else:
+ if is_decoder:
+ task_queue.append((func_name, module[i_level].upsamplers[0]))
+ else:
+ task_queue.append((func_name, module[i_level].downsamplers[0]))
+
+ if not is_decoder:
+ if sd_flag:
+ resblock2task(task_queue, net.mid.block_1)
+ attn2task(task_queue, net.mid.attn_1)
+ resblock2task(task_queue, net.mid.block_2)
+ else:
+ resblock2task(task_queue, net.mid_block.resnets[0])
+ attn2task(task_queue, net.mid_block.attentions[0])
+ resblock2task(task_queue, net.mid_block.resnets[1])
+
+
+def build_task_queue(net, is_decoder):
+ """
+ Build a single task queue for the encoder or decoder
+ @param net: the VAE decoder or encoder network
+ @param is_decoder: currently building decoder or encoder
+ @return: the task queue
+ """
+ task_queue = []
+ task_queue.append(('conv_in', net.conv_in))
+
+ # construct the sampling part of the task queue
+ # because encoder and decoder share the same architecture, we extract the sampling part
+ build_sampling(task_queue, net, is_decoder)
+ if is_decoder and not sd_flag:
+ net.give_pre_end = False
+ net.tanh_out = False
+
+ if not is_decoder or not net.give_pre_end:
+ if sd_flag:
+ task_queue.append(('pre_norm', net.norm_out))
+ else:
+ task_queue.append(('pre_norm', net.conv_norm_out))
+ task_queue.append(('silu', inplace_nonlinearity))
+ task_queue.append(('conv_out', net.conv_out))
+ if is_decoder and net.tanh_out:
+ task_queue.append(('tanh', torch.tanh))
+
+ return task_queue
+
+
+def clone_task_queue(task_queue):
+ """
+ Clone a task queue
+ @param task_queue: the task queue to be cloned
+ @return: the cloned task queue
+ """
+ return [[item for item in task] for task in task_queue]
+
+
+def get_var_mean(input, num_groups, eps=1e-6):
+ """
+ Get mean and var for group norm
+ """
+ b, c = input.size(0), input.size(1)
+ channel_in_group = int(c/num_groups)
+ input_reshaped = input.contiguous().view(
+ 1, int(b * num_groups), channel_in_group, *input.size()[2:])
+ var, mean = torch.var_mean(
+ input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
+ return var, mean
+
+
+def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
+ """
+ Custom group norm with fixed mean and var
+
+ @param input: input tensor
+ @param num_groups: number of groups. by default, num_groups = 32
+ @param mean: mean, must be pre-calculated by get_var_mean
+ @param var: var, must be pre-calculated by get_var_mean
+ @param weight: weight, should be fetched from the original group norm
+ @param bias: bias, should be fetched from the original group norm
+ @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
+
+ @return: normalized tensor
+ """
+ b, c = input.size(0), input.size(1)
+ channel_in_group = int(c/num_groups)
+ input_reshaped = input.contiguous().view(
+ 1, int(b * num_groups), channel_in_group, *input.size()[2:])
+
+ out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
+ training=False, momentum=0, eps=eps)
+
+ out = out.view(b, c, *input.size()[2:])
+
+ # post affine transform
+ if weight is not None:
+ out *= weight.view(1, -1, 1, 1)
+ if bias is not None:
+ out += bias.view(1, -1, 1, 1)
+ return out
+
+
+def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
+ """
+ Crop the valid region from the tile
+ @param x: input tile
+ @param input_bbox: original input bounding box
+ @param target_bbox: output bounding box
+ @param is_decoder: True if the tile comes from the decoder (bbox scaled up by 8), False for the encoder (scaled down by 8)
+ @return: cropped tile
+ """
+ padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
+ margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
+ return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
+
+# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
+
+
+def perfcount(fn):
+ def wrapper(*args, **kwargs):
+ ts = time()
+
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats(devices.device)
+ devices.torch_gc()
+ gc.collect()
+
+ ret = fn(*args, **kwargs)
+
+ devices.torch_gc()
+ gc.collect()
+ if torch.cuda.is_available():
+ vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
+ torch.cuda.reset_peak_memory_stats(devices.device)
+ print(
+ f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
+ else:
+ print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
+
+ return ret
+ return wrapper
+
+# copy end :)
+
+
+class GroupNormParam:
+ def __init__(self):
+ self.var_list = []
+ self.mean_list = []
+ self.pixel_list = []
+ self.weight = None
+ self.bias = None
+
+ def add_tile(self, tile, layer):
+ var, mean = get_var_mean(tile, 32)
+ # For giant images, the variance can be larger than max float16
+ # In this case we compute them on a float32 copy
+ if var.dtype == torch.float16 and var.isinf().any():
+ fp32_tile = tile.float()
+ var, mean = get_var_mean(fp32_tile, 32)
+ # ============= DEBUG: test for infinite =============
+ # if torch.isinf(var).any():
+ # print('var: ', var)
+ # ====================================================
+ self.var_list.append(var)
+ self.mean_list.append(mean)
+ self.pixel_list.append(
+ tile.shape[2]*tile.shape[3])
+ if hasattr(layer, 'weight'):
+ self.weight = layer.weight
+ self.bias = layer.bias
+ else:
+ self.weight = None
+ self.bias = None
+
+ def summary(self):
+ """
+ summarize the mean and var and return a function
+ that applies group norm on each tile
+ """
+ if len(self.var_list) == 0:
+ return None
+ var = torch.vstack(self.var_list)
+ mean = torch.vstack(self.mean_list)
+ max_value = max(self.pixel_list)
+ pixels = torch.tensor(
+ self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
+ sum_pixels = torch.sum(pixels)
+ pixels = pixels.unsqueeze(
+ 1) / sum_pixels
+ var = torch.sum(
+ var * pixels, dim=0)
+ mean = torch.sum(
+ mean * pixels, dim=0)
+ return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
+
+ @staticmethod
+ def from_tile(tile, norm):
+ """
+ create a function from a single tile without summary
+ """
+ var, mean = get_var_mean(tile, 32)
+ if var.dtype == torch.float16 and var.isinf().any():
+ fp32_tile = tile.float()
+ var, mean = get_var_mean(fp32_tile, 32)
+ # on Apple MPS devices we need to convert back to float16
+ if var.device.type == 'mps':
+ # clamp to avoid overflow
+ var = torch.clamp(var, 0, 60000)
+ var = var.half()
+ mean = mean.half()
+ if hasattr(norm, 'weight'):
+ weight = norm.weight
+ bias = norm.bias
+ else:
+ weight = None
+ bias = None
+
+ def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
+ return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
+ return group_norm_func
+
+
+class VAEHook:
+ def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
+ self.net = net # encoder | decoder
+ self.tile_size = tile_size
+ self.is_decoder = is_decoder
+ self.fast_mode = (fast_encoder and not is_decoder) or (
+ fast_decoder and is_decoder)
+ self.color_fix = color_fix and not is_decoder
+ self.to_gpu = to_gpu
+ self.pad = 11 if is_decoder else 32
+
+ def __call__(self, x):
+ B, C, H, W = x.shape
+ original_device = next(self.net.parameters()).device
+ try:
+ if self.to_gpu:
+ self.net.to(devices.get_optimal_device())
+ if max(H, W) <= self.pad * 2 + self.tile_size:
+ print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
+ return self.net.original_forward(x)
+ else:
+ return self.vae_tile_forward(x)
+ finally:
+ self.net.to(original_device)
+
+ def get_best_tile_size(self, lowerbound, upperbound):
+ """
+ Get the best tile size for GPU memory
+ """
+ divider = 32
+ while divider >= 2:
+ remainder = lowerbound % divider
+ if remainder == 0:
+ return lowerbound
+ candidate = lowerbound - remainder + divider
+ if candidate <= upperbound:
+ return candidate
+ divider //= 2
+ return lowerbound
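+ # Rounds `lowerbound` up to a multiple of the largest power-of-two divider
+ # (32, 16, ...) that still fits within `upperbound`, falling back to
+ # `lowerbound` itself; e.g. (509, 512) -> 512 and (46, 64) -> 64.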
+
+ def split_tiles(self, h, w):
+ """
+ Tool function to split the image into tiles
+ @param h: height of the image
+ @param w: width of the image
+ @return: tile_input_bboxes, tile_output_bboxes
+ """
+ tile_input_bboxes, tile_output_bboxes = [], []
+ tile_size = self.tile_size
+ pad = self.pad
+ num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
+ num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
+ # If either count is 0, set it to 1
+ # This is to deal with long and thin images
+ num_height_tiles = max(num_height_tiles, 1)
+ num_width_tiles = max(num_width_tiles, 1)
+
+ # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
+ real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
+ real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
+ real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
+ real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
+
+ print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
+ f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
+
+ for i in range(num_height_tiles):
+ for j in range(num_width_tiles):
+ # bbox: [x1, x2, y1, y2]
+ # the padding is unnecessary at the image borders, so we directly start from (pad, pad)
+ input_bbox = [
+ pad + j * real_tile_width,
+ min(pad + (j + 1) * real_tile_width, w),
+ pad + i * real_tile_height,
+ min(pad + (i + 1) * real_tile_height, h),
+ ]
+
+ # if the output bbox is close to the image boundary, we extend it to the image boundary
+ output_bbox = [
+ input_bbox[0] if input_bbox[0] > pad else 0,
+ input_bbox[1] if input_bbox[1] < w - pad else w,
+ input_bbox[2] if input_bbox[2] > pad else 0,
+ input_bbox[3] if input_bbox[3] < h - pad else h,
+ ]
+
+ # scale to get the final output bbox
+ output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
+ tile_output_bboxes.append(output_bbox)
+
+ # uniformly expand the input bbox by pad pixels on every side (clamped to the image)
+ tile_input_bboxes.append([
+ max(0, input_bbox[0] - pad),
+ min(w, input_bbox[1] + pad),
+ max(0, input_bbox[2] - pad),
+ min(h, input_bbox[3] + pad),
+ ])
+
+ return tile_input_bboxes, tile_output_bboxes
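+ # Worked example (decoder, pad=11): for a 128x128 latent with tile_size=96,
+ # ceil((128-22)/96) = 2 tiles per side; the auto-shrunk tile size is
+ # ceil(106/2) = 53, rounded up to 64, so each input bbox spans 64 + 2*11
+ # latent pixels (clamped at the image border) and output bboxes are scaled by 8.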
+
+ @torch.no_grad()
+ def estimate_group_norm(self, z, task_queue, color_fix):
+ device = z.device
+ tile = z
+ last_id = len(task_queue) - 1
+ while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
+ last_id -= 1
+ if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
+ raise ValueError('No group norm found in the task queue')
+ # estimate until the last group norm
+ for i in range(last_id + 1):
+ task = task_queue[i]
+ if task[0] == 'pre_norm':
+ group_norm_func = GroupNormParam.from_tile(tile, task[1])
+ task_queue[i] = ('apply_norm', group_norm_func)
+ if i == last_id:
+ return True
+ tile = group_norm_func(tile)
+ elif task[0] == 'store_res':
+ task_id = i + 1
+ while task_id < last_id and task_queue[task_id][0] != 'add_res':
+ task_id += 1
+ if task_id >= last_id:
+ continue
+ task_queue[task_id][1] = task[1](tile)
+ elif task[0] == 'add_res':
+ tile += task[1].to(device)
+ task[1] = None
+ elif color_fix and task[0] == 'downsample':
+ for j in range(i, last_id + 1):
+ if task_queue[j][0] == 'store_res':
+ task_queue[j] = ('store_res_cpu', task_queue[j][1])
+ return True
+ else:
+ tile = task[1](tile)
+ try:
+ devices.test_for_nans(tile, "vae")
+ except Exception:
+ print('NaN detected in fast mode estimation. Fast mode disabled.')
+ return False
+
+ raise IndexError('Should not reach here')
+
+ @perfcount
+ @torch.no_grad()
+ def vae_tile_forward(self, z):
+ """
+ Decode a latent vector z into an image in a tiled manner.
+ @param z: latent vector
+ @return: image
+ """
+ device = next(self.net.parameters()).device
+ dtype = z.dtype
+ net = self.net
+ tile_size = self.tile_size
+ is_decoder = self.is_decoder
+
+ z = z.detach() # detach the input to avoid backprop
+
+ N, height, width = z.shape[0], z.shape[2], z.shape[3]
+ net.last_z_shape = z.shape
+
+ # Split the input into tiles and build a task queue for each tile
+ print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
+
+ in_bboxes, out_bboxes = self.split_tiles(height, width)
+
+ # Prepare tiles by splitting the input latents
+ tiles = []
+ for input_bbox in in_bboxes:
+ tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
+ tiles.append(tile)
+
+ num_tiles = len(tiles)
+ num_completed = 0
+
+ # Build task queues
+ single_task_queue = build_task_queue(net, is_decoder)
+ #print(single_task_queue)
+ if self.fast_mode:
+ # Fast mode: downsample the input image to the tile size,
+ # then estimate the group norm parameters on the downsampled image
+ scale_factor = tile_size / max(height, width)
+ z = z.to(device)
+ downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
+ # use nearest-exact to keep the statistics as close as possible
+ print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
+
+ # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
+ # The downsampling will heavily distort its mean and std, so we need to recover it.
+ std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
+ std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
+ downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
+ del std_old, mean_old, std_new, mean_new
+ # occasionally the std_new is too small or too large, which exceeds the range of float16
+ # so we need to clamp it to max z's range.
+ downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
+ estimate_task_queue = clone_task_queue(single_task_queue)
+ if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
+ single_task_queue = estimate_task_queue
+ del downsampled_z
+
+ task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
+
+ # Dummy result
+ result = None
+ result_approx = None
+ #try:
+ # with devices.autocast():
+ # result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
+ #except: pass
+ # Free memory of input latent tensor
+ del z
+
+ # Task queue execution
+ pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
+
+ # execute the tasks back and forth when switching tiles so that we always
+ # keep one tile on the GPU to reduce unnecessary data transfer
+ forward = True
+ interrupted = False
+ #state.interrupted = interrupted
+ while True:
+ #if state.interrupted: interrupted = True ; break
+
+ group_norm_param = GroupNormParam()
+ for i in range(num_tiles) if forward else reversed(range(num_tiles)):
+ #if state.interrupted: interrupted = True ; break
+
+ tile = tiles[i].to(device)
+ input_bbox = in_bboxes[i]
+ task_queue = task_queues[i]
+
+ interrupted = False
+ while len(task_queue) > 0:
+ #if state.interrupted: interrupted = True ; break
+
+ # DEBUG: current task
+ # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
+ task = task_queue.pop(0)
+ if task[0] == 'pre_norm':
+ group_norm_param.add_tile(tile, task[1])
+ break
+ elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
+ task_id = 0
+ res = task[1](tile)
+ if not self.fast_mode or task[0] == 'store_res_cpu':
+ res = res.cpu()
+ while task_queue[task_id][0] != 'add_res':
+ task_id += 1
+ task_queue[task_id][1] = res
+ elif task[0] == 'add_res':
+ tile += task[1].to(device)
+ task[1] = None
+ else:
+ tile = task[1](tile)
+ #print(tiles[i].shape, tile.shape, task)
+ pbar.update(1)
+
+ if interrupted: break
+
+ # check for NaNs in the tile.
+ # If there are NaNs, we abort the process to save user's time
+ #devices.test_for_nans(tile, "vae")
+
+ #print(tiles[i].shape, tile.shape, i, num_tiles)
+ if len(task_queue) == 0:
+ tiles[i] = None
+ num_completed += 1
+ if result is None: # NOTE: dim C varies between cases, so the result buffer can only be initialized dynamically
+ result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
+ result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
+ del tile
+ elif i == num_tiles - 1 and forward:
+ forward = False
+ tiles[i] = tile
+ elif i == 0 and not forward:
+ forward = True
+ tiles[i] = tile
+ else:
+ tiles[i] = tile.cpu()
+ del tile
+
+ if interrupted: break
+ if num_completed == num_tiles: break
+
+ # insert the group norm task to the head of each task queue
+ group_norm_func = group_norm_param.summary()
+ if group_norm_func is not None:
+ for i in range(num_tiles):
+ task_queue = task_queues[i]
+ task_queue.insert(0, ('apply_norm', group_norm_func))
+
+ # Done!
+ pbar.close()
+ return result.to(dtype) if result is not None else result_approx.to(device)
\ No newline at end of file
diff --git a/app.py b/app.py
index 8c56db5902062894dad68eb34126b4d86c1bc272..617f099c56c9436762d3445a2b5ce836e9c6dfc7 100644
--- a/app.py
+++ b/app.py
@@ -1,78 +1,909 @@
-import torch
-from PIL import Image
-from RealESRGAN import RealESRGAN
-import gradio as gr
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model2 = RealESRGAN(device, scale=2)
-model2.load_weights('weights/RealESRGAN_x2.pth', download=True)
-model4 = RealESRGAN(device, scale=4)
-model4.load_weights('weights/RealESRGAN_x4.pth', download=True)
-model8 = RealESRGAN(device, scale=8)
-model8.load_weights('weights/RealESRGAN_x8.pth', download=True)
-
-
-def inference(image, size):
- global model2
- global model4
- global model8
- if image is None:
- raise gr.Error("Image not uploaded")
-
-
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- if size == '2x':
- try:
- result = model2.predict(image.convert('RGB'))
- except torch.cuda.OutOfMemoryError as e:
- print(e)
- model2 = RealESRGAN(device, scale=2)
- model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
- result = model2.predict(image.convert('RGB'))
- elif size == '4x':
- try:
- result = model4.predict(image.convert('RGB'))
- except torch.cuda.OutOfMemoryError as e:
- print(e)
- model4 = RealESRGAN(device, scale=4)
- model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
- result = model2.predict(image.convert('RGB'))
- else:
- try:
- width, height = image.size
- if width >= 5000 or height >= 5000:
- raise gr.Error("The image is too large.")
- result = model8.predict(image.convert('RGB'))
- except torch.cuda.OutOfMemoryError as e:
- print(e)
- model8 = RealESRGAN(device, scale=8)
- model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
- result = model2.predict(image.convert('RGB'))
-
- print(f"Image size ({device}): {size} ... OK")
- return result
-
-
-title = "Face Real ESRGAN UpScale: 2x 4x 8x"
-description = "This is an unofficial demo for Real-ESRGAN. Scales the resolution of a photo. This model shows better results on faces compared to the original version.
Telegram BOT: https://t.me/restoration_photo_bot"
-article = "
This is an online demo of SUPIR, a practicing model scaling for photo-realistic image restoration. + The content added by SUPIR is imagination, not real-world information. + SUPIR is for beauty and illustration only. + Most of the processes last few minutes. + If you want to upscale AI-generated images, be noticed that PixArt Sigma space can directly generate 5984x5984 images. + Due to Gradio issues, the generated image is slightly less satured than the original. + Please leave a message in discussion if you encounter issues. + You can also use AuraSR to upscale x4. + +
⚠️To use SUPIR, duplicate this space and set a GPU with 30 GB VRAM. + + You can't use SUPIR directly here because this space runs on a CPU, which is not enough for SUPIR. Please provide feedback if you have issues. +
+ """) + gr.HTML(title_html) + + input_image = gr.Image(label="Input (*.png, *.webp, *.jpeg, *.jpg, *.gif, *.bmp, *.heic)", show_label=True, type="filepath", height=600, elem_id="image-input") + rotation = gr.Radio([["No rotation", 0], ["⤵ Rotate +90°", 90], ["↩ Return 180°", 180], ["⤴ Rotate -90°", -90]], label="Orientation correction", info="Will apply the following rotation before restoring the image; the AI needs a good orientation to understand the content", value=0, interactive=True, visible=False) + with gr.Group(): + prompt = gr.Textbox(label="Image description", info="Help the AI understand what the image represents; describe as much as possible, especially the details we can't see on the original image; you can write in any language", value="", placeholder="A 33 years old man, walking, in the street, Santiago, morning, Summer, photorealistic", lines=3) + prompt_hint = gr.HTML("You can use a LlaVa space to auto-generate the description of your image.") + upscale = gr.Radio([["x1", 1], ["x2", 2], ["x3", 3], ["x4", 4], ["x5", 5], ["x6", 6], ["x7", 7], ["x8", 8], ["x9", 9], ["x10", 10], ["x20", 20], ["x100", 100]], label="Upscale factor", info="Resolution x1 to x100", value=2, interactive=True) + output_format = gr.Radio([["As input", "input"], ["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extention", value="input", interactive=True) + allocation = gr.Radio([["1 min", 1], ["2 min", 2], ["3 min", 3], ["4 min", 4], ["5 min", 5]], label="GPU allocation time", info="lower=May abort run, higher=Quota penalty for next runs", value=3, interactive=True) + + with gr.Accordion("Pre-denoising (optional)", open=False): + gamma_correction = gr.Slider(label="Gamma Correction", info = "lower=lighter, higher=darker", minimum=0.1, maximum=2.0, value=1.0, step=0.1) + denoise_button = gr.Button(value="Pre-denoise") + denoise_image = gr.Image(label="Denoised image", show_label=True, type="filepath", sources=[], interactive = False, height=600, elem_id="image-s1") + denoise_information = gr.HTML(value="If present, the denoised image will be used for the restoration instead of the input image.", visible=False) + + with gr.Accordion("Advanced options", open=False): + a_prompt = gr.Textbox(label="Additional image description", + info="Completes the main image description", + value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R ' + 'camera, hyper detailed photo - realistic maximum detail, 32k, Color ' + 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, clothing fabric detailing, ' + 'hyper sharpness, perfect without deformations.', + lines=3) + n_prompt = gr.Textbox(label="Negative image description", + info="Disambiguate by listing what the image does NOT represent", + value='painting, oil painting, illustration, drawing, art, sketch, anime, ' + 'cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, ' + 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, ' + 'deformed, lowres, over-smooth', + lines=3) + edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details; too many steps create a checker effect", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1) + num_samples = gr.Slider(label="Num Samples", info="Number of generated results", minimum=1, maximum=4 if not args.use_image_slider else 1 + , value=1, step=1) + 
min_size = gr.Slider(label="Minimum size", info="Minimum height, minimum width of the result", minimum=32, maximum=4096, value=1024, step=32) + downscale = gr.Radio([["/1", 1], ["/2", 2], ["/3", 3], ["/4", 4], ["/5", 5], ["/6", 6], ["/7", 7], ["/8", 8], ["/9", 9], ["/10", 10]], label="Pre-downscale factor", info="Reducing blurred image reduce the process time", value=1, interactive=True) + with gr.Row(): + with gr.Column(): + model_select = gr.Radio([["💃 Quality (v0-Q)", "v0-Q"], ["🎯 Fidelity (v0-F)", "v0-F"]], label="Model Selection", info="Pretrained model", value="v0-Q", + interactive=True) + with gr.Column(): + color_fix_type = gr.Radio([["None", "None"], ["AdaIn (improve as a photo)", "AdaIn"], ["Wavelet (for JPEG artifacts)", "Wavelet"]], label="Color-Fix Type", info="AdaIn=Improve following a style, Wavelet=For JPEG artifacts", value="AdaIn", + interactive=True) + s_cfg = gr.Slider(label="Text Guidance Scale", info="lower=follow the image, higher=follow the prompt", minimum=1.0, maximum=15.0, + value=default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.1) + s_stage2 = gr.Slider(label="Restoring Guidance Strength", minimum=0., maximum=1., value=1., step=0.05) + s_stage1 = gr.Slider(label="Pre-denoising Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0) + s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1) + s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001) + with gr.Row(): + with gr.Column(): + linear_CFG = gr.Checkbox(label="Linear CFG", value=True) + spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0, + maximum=9.0, value=default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.5) + with gr.Column(): + linear_s_stage2 = gr.Checkbox(label="Linear Restoring Guidance", value=False) + spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0., + maximum=1., value=0., step=0.05) + with gr.Column(): + diff_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["fp16 (medium)", "fp16"], ["bf16 (speed)", "bf16"]], label="Diffusion Data Type", value="fp32", + interactive=True) + with gr.Column(): + ae_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["bf16 (speed)", "bf16"]], label="Auto-Encoder Data Type", value="fp32", + interactive=True) + randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different") + seed = gr.Slider(label="Seed", minimum=0, maximum=max_64_bit_int, step=1, randomize=True) + with gr.Group(): + param_setting = gr.Radio(["Quality", "Fidelity"], interactive=True, label="Presetting", value = "Quality") + restart_button = gr.Button(value="Apply presetting") + + with gr.Column(): + diffusion_button = gr.Button(value="🚀 Upscale/Restore", variant = "primary", elem_id = "process_button") + reset_btn = gr.Button(value="🧹 Reinit page", variant="stop", elem_id="reset_button", visible = False) + + warning = gr.HTML(value = "chrome://discards/