|
import os |
|
import torch |
|
import torch.nn as nn |
|
from safetensors.torch import load_file |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class Upscaler(nn.Module):
    """
    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py

    The network is one ``nn.Sequential`` made of three stages, built in
    this order: ``head()`` (conv, ReLU, nearest-neighbour upsample, ReLU),
    ``core()`` (``depth`` conv + ReLU pairs) and ``tail()`` (a single conv
    back to ``chan`` channels).
    """

    version = 2.1

    def __init__(self, fac, depth=16):
        """Build the network.

        Args:
            fac: Scale factor forwarded to ``nn.Upsample(scale_factor=...)``.
            depth: Number of conv + ReLU pairs in the core stage.
        """
        super().__init__()
        self.size = 64      # width (channels) of the hidden conv layers
        self.chan = 4       # channels of the network's input and output
        self.depth = depth  # conv + ReLU pairs produced by core()
        self.fac = fac      # scale factor handed to nn.Upsample
        self.krn = 3        # conv kernel size
        self.pad = 1        # conv padding

        # Stage order is kept as head -> core -> tail so the layer indices
        # inside `sequential` (and hence the state-dict keys) do not move.
        stages = [*self.head(), *self.core(), *self.tail()]
        self.sequential = nn.Sequential(*stages)

    def _conv(self, in_channels, out_channels):
        """Return a Conv2d using this model's kernel size and padding."""
        return nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=self.krn,
            padding=self.pad,
        )

    def head(self):
        """Return the input stage: conv, ReLU, nearest upsample, ReLU."""
        entry_conv = self._conv(self.chan, self.size)
        upsample = nn.Upsample(scale_factor=self.fac, mode="nearest")
        return [entry_conv, nn.ReLU(), upsample, nn.ReLU()]

    def core(self):
        """Return ``depth`` consecutive conv + ReLU pairs as a flat list."""
        blocks = []
        for _ in range(self.depth):
            blocks.append(self._conv(self.size, self.size))
            blocks.append(nn.ReLU())
        return blocks

    def tail(self):
        """Return the output stage: one conv from ``size`` to ``chan`` channels."""
        return [self._conv(self.size, self.chan)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return the result."""
        return self.sequential(x)
|
|
|
|
|
class LatentUpscaler:
    """Node class that upscales a latent with pretrained ``Upscaler`` weights.

    The class exposes the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` attributes read by the host application's node loader
    (NOTE(review): layout matches ComfyUI custom nodes -- confirm host).
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a latent plus two dropdown choices."""
        return {
            "required": {
                "samples": ("LATENT",),
                "latent_ver": (["v1", "xl"],),
                "scale_factor": (["1.25", "1.5", "2.0"],),
            }
        }

    @staticmethod
    def _locate_weights(filename):
        """Return a path to ``filename``, preferring the local ``models`` dir.

        Looks in the ``models`` directory next to this file first; if the
        file is not there it is fetched with ``hf_hub_download`` from the
        ``city96/SD-Latent-Upscaler`` repository.
        """
        models_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "models"
        )
        local = os.path.join(models_dir, filename)
        if os.path.isfile(local):
            print("LatentUpscaler: Using local model")
            return local
        print("LatentUpscaler: Using HF Hub model")
        return str(hf_hub_download(
            repo_id="city96/SD-Latent-Upscaler",
            filename=filename,
        ))

    def upscale(self, samples, latent_ver, scale_factor):
        """Upscale ``samples["samples"]`` (and ``noise_mask`` if present).

        Args:
            samples: Mapping holding the latent tensor under ``"samples"``
                and, optionally, a mask tensor under ``"noise_mask"``.
            latent_ver: Latent family tag used in the weight file name
                (``"v1"`` or ``"xl"`` per ``INPUT_TYPES``).
            scale_factor: Scale factor as text (``"1.25"``, ``"1.5"`` or
                ``"2.0"`` per ``INPUT_TYPES``); the text form is used
                verbatim in the weight file name.

        Returns:
            A 1-tuple containing a dict with the upscaled ``"samples"`` and,
            when the input had one, the resized ``"noise_mask"``. Errors
            raised while locating or loading the weights propagate unchanged.
        """
        # The dropdown delivers text; convert once so the network and the
        # mask interpolation both receive a real number.
        factor = float(scale_factor)

        model = Upscaler(factor)
        filename = f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"
        model.load_state_dict(load_file(self._locate_weights(filename)))

        lt = samples["samples"]
        # Run on whatever device the latent lives on, in inference mode and
        # without recording an autograd graph for the returned tensor.
        model.to(lt.device).eval()
        with torch.no_grad():
            lt = model(lt)
        del model  # release the weights before building the result

        if "noise_mask" in samples:
            # Bicubic interpolation expects a 4-D (N, C, H, W) tensor.
            # NOTE(review): confirm the mask shape supplied by callers.
            mask = torch.nn.functional.interpolate(
                samples["noise_mask"], scale_factor=factor, mode="bicubic"
            )
            return ({"samples": lt, "noise_mask": mask},)
        return ({"samples": lt},)
|
|
|
# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader.
NODE_CLASS_MAPPINGS = {"LatentUpscaler": LatentUpscaler}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {"LatentUpscaler": "Latent Upscaler"}
|
|