File size: 2,618 Bytes
baa8e90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import torch
import torch.nn as nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


class Upscaler(nn.Module):
	"""
		Basic NN layout, ported from:
		https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py

		Layout of `self.sequential`:
		  head: conv(chan -> size), ReLU, nearest upsample by `fac`, ReLU
		  core: `depth` x [conv(size -> size), ReLU]
		  tail: conv(size -> chan)
		Keep the layer order unchanged: the Sequential indices form the
		state_dict keys that pretrained weight files are loaded against.
	"""
	version = 2.1 # network revision

	def __init__(self, fac, depth=16):
		super().__init__()
		# hyper-parameters, read by head()/core()/tail() when building layers
		self.size = 64      # Conv2d size
		self.chan = 4       # in/out channels
		self.depth = depth  # no. of layers
		self.fac = fac      # scale factor
		self.krn = 3        # kernel size
		self.pad = 1        # padding

		stack = self.head() + self.core() + self.tail()
		self.sequential = nn.Sequential(*stack)

	def _conv(self, in_ch, out_ch):
		# every convolution in the net shares the same kernel size and padding
		return nn.Conv2d(in_ch, out_ch, kernel_size=self.krn, padding=self.pad)

	def head(self):
		"""Widen to `size` channels, then enlarge spatially by `fac`."""
		return [
			self._conv(self.chan, self.size),
			nn.ReLU(),
			nn.Upsample(scale_factor=self.fac, mode="nearest"),
			nn.ReLU(),
		]

	def core(self):
		"""Return `depth` conv+ReLU pairs at constant width, flattened."""
		return [
			layer
			for _ in range(self.depth)
			for layer in (self._conv(self.size, self.size), nn.ReLU())
		]

	def tail(self):
		"""Project back to `chan` channels; no activation on the output."""
		return [self._conv(self.size, self.chan)]

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Run the full head/core/tail stack on `x`."""
		return self.sequential(x)


class LatentUpscaler:
	"""
		Node that upscales a LATENT with a pretrained `Upscaler` network.

		Weights are taken from the "models" folder next to this file when the
		matching .safetensors file exists there, otherwise they are downloaded
		from the "city96/SD-Latent-Upscaler" repo on the Hugging Face Hub.
	"""
	def __init__(self):
		pass

	@classmethod
	def INPUT_TYPES(s):
		return {
			"required": {
				"samples": ("LATENT", ),
				"latent_ver": (["v1", "xl"],),
				"scale_factor": (["1.25", "1.5", "2.0"],),
			}
		}

	RETURN_TYPES = ("LATENT",)
	FUNCTION = "upscale"
	CATEGORY = "latent"

	@staticmethod
	def _resize_noise_mask(mask, scale_factor):
		"""
			Resize a noise mask by `scale_factor` and keep it in [0, 1].

			Bicubic interpolation overshoots around sharp edges (negative values
			or values above 1), so the result is clamped back into mask range.
			`mask` is passed straight to F.interpolate — presumably (B, 1, H, W);
			confirm against the producer of "noise_mask".
		"""
		mask = torch.nn.functional.interpolate(
			mask, scale_factor=float(scale_factor), mode="bicubic"
		)
		return mask.clamp(0.0, 1.0)

	def upscale(self, samples, latent_ver, scale_factor):
		"""
			Upscale `samples["samples"]` with the network matching the inputs.

			Args:
				samples: latent dict holding a "samples" tensor and optionally a
					"noise_mask" tensor; any other keys are carried over as-is.
				latent_ver: "v1" or "xl", selects which weight file to load.
				scale_factor: one of "1.25", "1.5", "2.0" (kept as a string since
					it is also part of the weight filename).
			Returns:
				1-tuple with the upscaled latent dict.
		"""
		model = Upscaler(scale_factor)
		filename = f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"
		local = os.path.join(
			os.path.dirname(os.path.realpath(__file__)), "models", filename
		)

		if os.path.isfile(local):
			print("LatentUpscaler: Using local model")
			weights = local
		else:
			print("LatentUpscaler: Using HF Hub model")
			weights = str(hf_hub_download(
				repo_id="city96/SD-Latent-Upscaler",
				filename=filename,
			))

		model.load_state_dict(load_file(weights))
		lt = samples["samples"]
		# Match the latent's device/dtype so GPU or half-precision latents work,
		# and skip autograd: this is pure inference, and tracking gradients would
		# attach a grad_fn to the output and keep every activation alive.
		model.to(device=lt.device, dtype=lt.dtype)
		with torch.no_grad():
			lt = model(lt)
		del model

		out = samples.copy()  # carry over any other keys of the incoming latent dict
		out["samples"] = lt
		if "noise_mask" in samples:
			out["noise_mask"] = self._resize_noise_mask(samples["noise_mask"], scale_factor)
		return (out,)

# Node registration tables: node id -> implementing class,
# and node id -> human-readable display name.
NODE_CLASS_MAPPINGS = dict(
	LatentUpscaler=LatentUpscaler,
)

NODE_DISPLAY_NAME_MAPPINGS = dict(
	LatentUpscaler="Latent Upscaler",
)