|
import os
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from omegaconf import OmegaConf
|
|
import comfy.utils
|
|
import comfy.model_management as mm
|
|
import folder_paths
|
|
from nodes import ImageScaleBy
|
|
from nodes import ImageScale
|
|
import torch.cuda
|
|
from .sgm.util import instantiate_from_config
|
|
from .SUPIR.util import convert_dtype, load_state_dict
|
|
import open_clip
|
|
from contextlib import contextmanager
|
|
|
|
from transformers import (
|
|
CLIPTextModel,
|
|
CLIPTokenizer,
|
|
CLIPTextConfig,
|
|
|
|
)
|
|
# Absolute path of the directory holding this file; the option/config/tokenizer
# paths used by the node are joined onto it.
script_directory = os.path.split(os.path.abspath(__file__))[0]
|
|
|
|
def dummy_build_vision_tower(*args, **kwargs):
    """No-op replacement for ``open_clip.model._build_vision_tower``.

    Accepts and ignores any positional/keyword arguments, so it can stand in
    for the real builder regardless of how it is called, and always yields
    ``None`` instead of a vision tower.
    """
    del args, kwargs  # deliberately unused
    return None
|
|
|
|
@contextmanager
def patch_build_vision_tower():
    """Temporarily replace open_clip's vision-tower builder with ``dummy_build_vision_tower``.

    For the duration of the ``with`` block, ``open_clip.model._build_vision_tower``
    returns ``None``. Code in open_clip that looks the builder up at call time
    (presumably ``open_clip.CLIP.__init__``) therefore builds no image tower.
    The original builder is put back on exit, including when the block raises.
    """
    clip_model_module = open_clip.model
    saved_builder = clip_model_module._build_vision_tower
    clip_model_module._build_vision_tower = dummy_build_vision_tower
    try:
        yield
    finally:
        clip_model_module._build_vision_tower = saved_builder
|
|
|
|
def build_text_model_from_openai_state_dict(
    state_dict: dict,
    cast_dtype=torch.float16,
):
    """Construct a frozen, text-only ``open_clip.CLIP`` from OpenAI-format CLIP weights.

    All architecture hyper-parameters are read from tensor shapes and key
    names in ``state_dict``. The model is built inside
    ``patch_build_vision_tower`` with ``vision_cfg=None``, so no image tower is
    created. Weights are then loaded non-strictly, which ignores unmatched keys.

    Args:
        state_dict: CLIP weights whose keys include ``text_projection``,
            ``positional_embedding``, ``token_embedding.weight``,
            ``ln_final.weight`` and ``transformer.resblocks.<n>.*``.
        cast_dtype: passed through unchanged to ``open_clip.CLIP``.

    Returns:
        The ``open_clip.CLIP`` model in eval mode, with ``requires_grad=False``
        on every parameter.
    """
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    width = state_dict["ln_final.weight"].shape[0]
    # Keys look like "transformer.resblocks.<n>.<...>"; the layer count is the
    # number of distinct <n> values.
    resblock_ids = {key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")}

    text_cfg = open_clip.CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=width,
        heads=width // 64,  # one attention head per 64 channels
        layers=len(resblock_ids),
    )

    with patch_build_vision_tower():
        model = open_clip.CLIP(
            embed_dim,
            vision_cfg=None,
            text_cfg=text_cfg,
            quick_gelu=True,
            cast_dtype=cast_dtype,
        )

    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    # Equivalent to setting param.requires_grad = False on model.parameters().
    model.requires_grad_(False)
    return model
|
|
|
|
class SUPIR_Upscale:
    """ComfyUI node that upscales an image and restores it with SUPIR on top of an SDXL checkpoint.

    The built model is cached on the instance (``self.model``) together with
    the settings it was built from (``self.current_config``). It is rebuilt
    when those settings change, or after a run that dropped it
    (``keep_model_loaded=False`` or an out-of-memory error).
    """

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "sdxl_model": (folder_paths.get_filename_list("checkpoints"),),
            "image": ("IMAGE",),
            "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "resize_method": (s.upscale_methods, {"default": "lanczos"}),
            "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01}),
            "steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}),
            "restoration_scale": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 6.0, "step": 1.0}),
            "cfg_scale": ("FLOAT", {"default": 4.0, "min": 0, "max": 100, "step": 0.01}),
            "a_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed", }),
            "n_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy", }),
            "s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}),
            "s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}),
            "control_scale": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}),
            "cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.05}),
            "control_scale_start": ("FLOAT", {"default": 0.0, "min": 0, "max": 1.0, "step": 0.05}),
            "color_fix_type": (
                [
                    'None',
                    'AdaIn',
                    'Wavelet',
                ], {
                    "default": 'Wavelet'
                }),
            "keep_model_loaded": ("BOOLEAN", {"default": True}),
            "use_tiled_vae": ("BOOLEAN", {"default": True}),
            "encoder_tile_size_pixels": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "decoder_tile_size_latent": ("INT", {"default": 64, "min": 32, "max": 8192, "step": 64}),
        },
            "optional": {
                "captions": ("STRING", {"forceInput": True, "multiline": False, "default": "", }),
                "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "encoder_dtype": (
                    [
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}),
                "use_tiled_sampling": ("BOOLEAN", {"default": False}),
                "sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}),
                "sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}),
                "fp8_unet": ("BOOLEAN", {"default": False}),
                "fp8_vae": ("BOOLEAN", {"default": False}),
                "sampler": (
                    [
                        'RestoreDPMPP2MSampler',
                        'RestoreEDMSampler',
                    ], {
                        "default": 'RestoreEDMSampler'
                    }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("upscaled_image",)
    FUNCTION = "process"

    CATEGORY = "SUPIR"

    @staticmethod
    def _resolve_diffusion_dtype(diffusion_dtype):
        """Map the ``diffusion_dtype`` option to ``(torch dtype, config name)``; 'auto' probes ComfyUI."""
        if diffusion_dtype != 'auto':
            print(f"Diffusion using using {diffusion_dtype}")
            return convert_dtype(diffusion_dtype), diffusion_dtype
        try:
            # Both probes are queried up front so an old ComfyUI missing either
            # one is reported the same way.
            use_fp16 = mm.should_use_fp16()
            use_bf16 = mm.should_use_bf16()
        except AttributeError as e:
            raise AttributeError("ComfyUI too old, can't autodetect properly. Set your dtypes manually.") from e
        # bf16 wins when both probes are true. fp16 is chosen when it is the only
        # option; previously it fell through to fp32 because of an if/if/else chain.
        if use_bf16:
            print("Diffusion using bf16")
            return torch.bfloat16, 'bf16'
        if use_fp16:
            print("Diffusion using fp16")
            return torch.float16, 'fp16'
        print("Diffusion using using fp32")
        return torch.float32, 'fp32'

    @staticmethod
    def _resolve_encoder_dtype(encoder_dtype):
        """Map the ``encoder_dtype`` option to the config name used for the autoencoder."""
        if encoder_dtype != 'auto':
            print(f"Encoder using using {encoder_dtype}")
            return encoder_dtype
        try:
            use_bf16 = mm.should_use_bf16()
        except AttributeError as e:
            raise AttributeError("ComfyUI too old, can't autodetect properly. Set your dtypes manually.") from e
        if use_bf16:
            print("Encoder using bf16")
            return 'bf16'
        print("Encoder using using fp32")
        return 'fp32'

    @staticmethod
    def _build_config(sampler, use_tiled_sampling, sampler_tile_size, sampler_tile_stride, ae_dtype, diffusion_dtype):
        """Load the SUPIR OmegaConf config and apply the sampler, attention-backend and dtype choices."""
        if use_tiled_sampling:
            config = OmegaConf.load(os.path.join(script_directory, "options/SUPIR_v0_tiled.yaml"))
            # Tile sizes are divided by 8 here; presumably the widget values are
            # in pixels and the sampler tiles in latent units.
            config.model.params.sampler_config.params.tile_size = sampler_tile_size // 8
            config.model.params.sampler_config.params.tile_stride = sampler_tile_stride // 8
            config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.Tiled{sampler}"
            print("Using tiled sampling")
        else:
            config = OmegaConf.load(os.path.join(script_directory, "options/SUPIR_v0.yaml"))
            config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.{sampler}"
            print("Using non-tiled sampling")

        if mm.XFORMERS_IS_AVAILABLE:
            config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

        config.model.params.ae_dtype = ae_dtype
        config.model.params.diffusion_dtype = diffusion_dtype
        return config

    @staticmethod
    def _load_weights(model, supir_model_path, sdxl_model_path, dtype):
        """Load SUPIR + SDXL weights into ``model`` and rebuild both SDXL text encoders.

        Raises:
            Exception: a checkpoint or text encoder failed to load (chained to the cause).
        """
        try:
            print(f'Attempting to load SUPIR model: [{supir_model_path}]')
            supir_state_dict = load_state_dict(supir_model_path)
        except Exception as e:
            raise Exception("Failed to load SUPIR model") from e
        try:
            print(f"Attempting to load SDXL model: [{sdxl_model_path}]")
            sdxl_state_dict = load_state_dict(sdxl_model_path)
        except Exception as e:
            raise Exception("Failed to load SDXL model") from e
        # strict=False: presumably neither checkpoint covers every key of the
        # combined model on its own.
        model.load_state_dict(supir_state_dict, strict=False)
        model.load_state_dict(sdxl_state_dict, strict=False)
        del supir_state_dict

        try:
            print("Loading first clip model from SDXL checkpoint")
            clip_l = model.conditioner.embedders[0]
            # NOTE(review): with filter_keys=False the returned dict presumably
            # still holds every SDXL key; the second encoder is read from it below.
            sd = comfy.utils.state_dict_prefix_replace(
                sdxl_state_dict, {"conditioner.embedders.0.transformer.": ""}, filter_keys=False)
            clip_text_config = CLIPTextConfig.from_pretrained(
                os.path.join(script_directory, "configs/clip_vit_config.json"))
            clip_l.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(script_directory, "configs/tokenizer"))
            clip_l.transformer = CLIPTextModel(clip_text_config)
            clip_l.transformer.load_state_dict(sd, strict=False)
            clip_l.eval()
            for param in clip_l.parameters():
                param.requires_grad = False
        except Exception as e:
            raise Exception("Failed to load first clip model from SDXL checkpoint") from e

        del sdxl_state_dict

        try:
            print("Loading second clip model from SDXL checkpoint")
            sd = comfy.utils.state_dict_prefix_replace(sd, {"conditioner.embedders.1.model.": ""}, filter_keys=True)
            clip_g = build_text_model_from_openai_state_dict(sd, cast_dtype=dtype)
            model.conditioner.embedders[1].model = clip_g
        except Exception as e:
            raise Exception("Failed to load second clip model from SDXL checkpoint") from e

        del sd, clip_g

    def process(self, steps, image, color_fix_type, seed, scale_by, cfg_scale, resize_method, s_churn, s_noise,
                encoder_tile_size_pixels, decoder_tile_size_latent,
                control_scale, cfg_scale_start, control_scale_start, restoration_scale, keep_model_loaded,
                a_prompt, n_prompt, sdxl_model, supir_model, use_tiled_vae, use_tiled_sampling=False,
                sampler_tile_size=128, sampler_tile_stride=64, captions="", diffusion_dtype="auto",
                encoder_dtype="auto", batch_size=1, fp8_unet=False, fp8_vae=False, sampler="RestoreEDMSampler"):
        """Upscale ``image`` by ``scale_by`` and restore it with SUPIR.

        Steps:
        1. Resize the image with ``resize_method``.
        2. Stretch each side up to the next multiple of 64.
        3. Sample in chunks of ``batch_size`` via ``batchify_sample``.
        4. Resize the result back to the size from step 1.

        Args (selection; sampler parameters go to ``batchify_sample`` unchanged):
            image: ComfyUI IMAGE tensor, indexed here as (B, H, W, C).
            supir_model / sdxl_model: file names in the "checkpoints" folder.
            diffusion_dtype: 'fp16' | 'bf16' | 'fp32' | 'auto'.
            encoder_dtype: 'bf16' | 'fp32' | 'auto'.
            keep_model_loaded: keep the built model cached on ``self`` afterwards.
            captions: a single caption, used for every image.

        Returns:
            One-element tuple holding the restored IMAGE tensor.

        Raises:
            AttributeError: 'auto' dtype on a ComfyUI without the dtype probes.
            Exception: a checkpoint or text encoder failed to load.
            torch.cuda.OutOfMemoryError: re-raised after dropping the cached model.
        """
        device = mm.get_torch_device()
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)
        SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model)

        custom_config = {
            'sdxl_model': sdxl_model,
            'diffusion_dtype': diffusion_dtype,
            'encoder_dtype': encoder_dtype,
            'use_tiled_vae': use_tiled_vae,
            'supir_model': supir_model,
            'use_tiled_sampling': use_tiled_sampling,
            'fp8_unet': fp8_unet,
            'fp8_vae': fp8_vae,
            'sampler': sampler,
            # Tile sizes are baked in at load time (sampler config / init_tile_vae),
            # so changing them must trigger a rebuild. They are ignored while the
            # corresponding feature is off.
            'sampler_tile_size': sampler_tile_size if use_tiled_sampling else None,
            'sampler_tile_stride': sampler_tile_stride if use_tiled_sampling else None,
            'encoder_tile_size_pixels': encoder_tile_size_pixels if use_tiled_vae else None,
            'decoder_tile_size_latent': decoder_tile_size_latent if use_tiled_vae else None,
        }

        dtype, model_dtype = self._resolve_diffusion_dtype(diffusion_dtype)
        vae_dtype = self._resolve_encoder_dtype(encoder_dtype)

        if getattr(self, "model", None) is None or getattr(self, "current_config", None) != custom_config:
            # Drop the old model and config first. The new pair is only published
            # once loading succeeds, so a failed load can't be reused as if valid.
            self.model = None
            self.current_config = None
            mm.soft_empty_cache()

            config = self._build_config(sampler, use_tiled_sampling, sampler_tile_size, sampler_tile_stride,
                                        vae_dtype, model_dtype)
            model = instantiate_from_config(config.model).cpu()
            self._load_weights(model, SUPIR_MODEL_PATH, SDXL_MODEL_PATH, dtype)
            mm.soft_empty_cache()

            model.to(dtype)
            # Only the UNet and/or VAE are cast to fp8.
            if fp8_unet:
                model.model.to(torch.float8_e4m3fn)
            if fp8_vae:
                model.first_stage_model.to(torch.float8_e4m3fn)

            if use_tiled_vae:
                model.init_tile_vae(encoder_tile_size=encoder_tile_size_pixels,
                                    decoder_tile_size=decoder_tile_size_latent)

            self.model = model
            self.current_config = custom_config

        upscaled_image, = ImageScaleBy.upscale(self, image, resize_method, scale_by)
        B, H, W, C = upscaled_image.shape
        # Round each side up to a multiple of 64. The image is stretched, not
        # padded; the result is resized back to (W, H) at the end.
        new_height = H if H % 64 == 0 else ((H // 64) + 1) * 64
        new_width = W if W % 64 == 0 else ((W // 64) + 1) * 64
        resized_image = F.interpolate(upscaled_image.permute(0, 3, 1, 2), size=(new_height, new_width),
                                      mode='bicubic', align_corners=False)

        captions_list = [captions]
        print("captions: ", captions_list)
        captions_list = captions_list * B

        use_linear_CFG = cfg_scale_start > 0
        use_linear_control_scale = control_scale_start > 0
        out = []
        pbar = comfy.utils.ProgressBar(B)

        mm.soft_empty_cache()
        sampled = 0
        for start in range(0, B, batch_size):
            # Move only the current chunk to the device to keep peak memory low.
            imgs = resized_image[start:start + batch_size].to(device)
            caps = captions_list[start:start + batch_size]
            try:
                samples = self.model.batchify_sample(imgs, caps, num_steps=steps,
                                                     restoration_scale=restoration_scale, s_churn=s_churn,
                                                     s_noise=s_noise, cfg_scale=cfg_scale, control_scale=control_scale,
                                                     seed=seed,
                                                     num_samples=1, p_p=a_prompt, n_p=n_prompt,
                                                     color_fix_type=color_fix_type,
                                                     use_linear_CFG=use_linear_CFG,
                                                     use_linear_control_scale=use_linear_control_scale,
                                                     cfg_scale_start=cfg_scale_start,
                                                     control_scale_start=control_scale_start)
            except torch.cuda.OutOfMemoryError:
                mm.free_memory(mm.get_total_memory(device), device)
                self.model = None
                mm.soft_empty_cache()
                print("It's likely that too large of an image or batch_size for SUPIR was used,"
                      " and it has devoured all of the memory it had reserved, you may need to restart ComfyUI. Make sure you are using tiled_vae, "
                      " you can also try using fp8 for reduced memory usage if your system supports it.")
                raise

            # Normalise every chunk to (n, C, H, W) so chunks of different sizes
            # (e.g. a short last batch) can be concatenated.
            samples = samples.squeeze(0)
            if samples.dim() == 3:
                samples = samples.unsqueeze(0)
            out.append(samples.cpu())

            sampled += imgs.shape[0]
            print("Sampled ", sampled, " out of ", B)
            pbar.update(imgs.shape[0])

        if not keep_model_loaded:
            self.model = None
            mm.soft_empty_cache()

        out_stacked = torch.cat(out, dim=0).to(torch.float32).permute(0, 2, 3, 1)

        final_image, = ImageScale.upscale(self, out_stacked, resize_method, W, H, crop="disabled")

        return (final_image,)
|
|
|
|
# Node registration tables, following the ComfyUI custom-node convention:
# node type name -> implementing class.
NODE_CLASS_MAPPINGS = {"SUPIR_Upscale": SUPIR_Upscale}

# node type name -> title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {"SUPIR_Upscale": "SUPIR_Upscale"}