Initial commit
- .DS_Store +0 -0
- app.py +223 -0
- checkpoints/checkpoint_epoch_015_20250808_154437.pt +3 -0
- examples/Places365_test_00000287.jpg +0 -0
- examples/Places365_test_00000314.jpg +0 -0
- model.py +236 -0
- requirements.txt +11 -0
.DS_Store
ADDED
Binary file (6.15 kB).
app.py
ADDED
@@ -0,0 +1,223 @@
# app.py — Gradio-native metrics, clean UI, CUDA/CPU only

import os, math, cv2, base64
import torch, numpy as np, gradio as gr
from PIL import Image

# Optional (fine if missing)
try:
    import kornia.color as kc
except Exception:
    kc = None

from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

# ---------------- Device & Model (no MPS) ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from model import ViTUNetColorizer
CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"

model = None
if os.path.exists(CKPT):
    model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    state = torch.load(CKPT, map_location=device)
    sd = state.get("generator_state_dict", state)
    model.load_state_dict(sd)
    model.eval()

# ---------------- Utils ----------------
def is_grayscale(img: Image.Image) -> bool:
    a = np.array(img)
    if a.ndim == 2: return True
    if a.ndim == 3 and a.shape[2] == 1: return True
    if a.ndim == 3 and a.shape[2] == 3:
        return np.allclose(a[...,0], a[...,1]) and np.allclose(a[...,1], a[...,2])
    return False

def to_L(rgb_np: np.ndarray):
    # ViTUNetColorizer expects L in [0,1]
    if kc is None:
        # OpenCV fallback: convert float32 RGB in [0,1] to Lab (L in [0,100]) so it matches the kornia path
        lab = cv2.cvtColor(rgb_np.astype(np.float32) / 255.0, cv2.COLOR_RGB2LAB)
        L = lab[..., 0] / 100.0
        return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    t = torch.from_numpy(rgb_np.astype(np.float32)/255.).permute(2,0,1).unsqueeze(0).to(device)
    with torch.no_grad():
        return kc.rgb_to_lab(t)[:,0:1]/100.0

def lab_to_rgb(L, ab):
    if kc is None:
        lab = torch.cat([L*100.0, torch.clamp(ab, -1, 1)*110.0], dim=1)[0].permute(1,2,0).cpu().numpy()
        lab = np.clip(lab, [0,-128,-128], [100,127,127]).astype(np.float32)
        rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return (np.clip(rgb,0,1)*255).astype(np.uint8)
    lab = torch.cat([L*100.0, torch.clamp(ab, -1, 1)*110.0], dim=1)
    with torch.no_grad():
        rgb = kc.lab_to_rgb(lab)
    return (torch.clamp(rgb,0,1)[0].permute(1,2,0).cpu().numpy()*255).astype(np.uint8)

def pad_to_multiple(img_np, m=16):
    h,w = img_np.shape[:2]
    ph, pw = math.ceil(h/m)*m, math.ceil(w/m)*m
    return cv2.copyMakeBorder(img_np,0,ph-h,0,pw-w,cv2.BORDER_CONSTANT,value=(0,0,0)), (h,w)

def compute_metrics(pred, gt):
    p = pred.astype(np.float32)/255.; g = gt.astype(np.float32)/255.
    mae = float(np.mean(np.abs(p-g)))
    psnr = float(psnr_metric(g, p, data_range=1.0))
    try:
        ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=7))
    except TypeError:
        ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
    return round(mae,4), round(psnr,2), round(ssim,4)

# ---------------- Inference ----------------
def infer(image: Image.Image, want_metrics: bool, sizing_mode: str, show_L: bool):
    if image is None:
        return None, None, None, None, None, "", ""
    if model is None:
        return None, None, None, None, None, "", "<div>Checkpoint not found in /checkpoints.</div>"

    pil = image.convert("RGB")
    rgb = np.array(pil)
    w,h = pil.size
    was_color = not is_grayscale(pil)

    if sizing_mode == "Pad to keep size":
        proc, (oh, ow) = pad_to_multiple(rgb, 16); back = (ow, oh)
    else:
        proc = cv2.resize(rgb, (256,256), interpolation=cv2.INTER_CUBIC); back = (w,h)

    L = to_L(proc)
    with torch.no_grad():
        ab = model(L)
    out = lab_to_rgb(L, ab)

    if sizing_mode == "Pad to keep size":
        out = out[:back[1], :back[0]]
    else:
        out = cv2.resize(out, back, interpolation=cv2.INTER_CUBIC)

    # Metrics (Gradio-native numbers)
    mae = psnr = ssim = None
    if want_metrics:
        mae, psnr, ssim = compute_metrics(out, np.array(pil))

    # Optional L preview
    extra_html = ""
    if show_L:
        L01 = np.clip(L[0,0].detach().cpu().numpy(),0,1)
        L_vis = (L01*255).astype(np.uint8)
        L_vis = cv2.cvtColor(L_vis, cv2.COLOR_GRAY2RGB)
        _, buf = cv2.imencode(".png", cv2.cvtColor(L_vis, cv2.COLOR_RGB2BGR))
        L_b64 = "data:image/png;base64," + base64.b64encode(buf).decode()
        extra_html += f"<div><b>L-channel</b><br/><img style='max-height:140px;border-radius:12px' src='{L_b64}'/></div>"

    # Subtle notice only if needed
    if was_color:
        extra_html += "<div style='opacity:.8;margin-top:8px'>We used a grayscale version of your image for colorization.</div>"

    # Compare slider (HTML only; easy to remove if you want 100% Gradio)
    _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR))
    _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
    sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
    compare = f"""
    <div style="position:relative;max-width:500px;margin:auto;border-radius:14px;overflow:hidden;box-shadow:0 8px 20px rgba(0,0,0,.2)">
      <img src="{so}" style="width:100%;display:block"/>
      <div id="cmpTop" style="position:absolute;top:0;left:0;height:100%;width:50%;overflow:hidden">
        <img src="{sc}" style="width:100%;display:block"/>
      </div>
      <input id="cmpRange" type="range" min="0" max="100" value="50"
             oninput="document.getElementById('cmpTop').style.width=this.value+'%';"
             style="position:absolute;left:0;right:0;bottom:8px;width:60%;margin:auto"/>
    </div>
    """

    return Image.fromarray(np.array(pil)), Image.fromarray(out), mae, psnr, ssim, compare, extra_html

# ---------------- Theme (fallback-safe) ----------------
def make_theme():
    try:
        from gradio.themes.utils import colors, fonts, sizes
        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
        ).set(radius_size=sizes.radius_lg, spacing_size=sizes.spacing_md)
    except Exception:
        return gr.themes.Soft()

THEME = make_theme()

# ---------------- UI ----------------
with gr.Blocks(theme=THEME, title="Neural Colorizer") as demo:
    gr.Markdown("# 🎨 Neural Colorizer")

    with gr.Row():
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload grayscale or color image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "clipboard"]
            )
            with gr.Row():
                sizing = gr.Radio(
                    ["Resize to 256", "Pad to keep size"],
                    value="Resize to 256",
                    label="Sizing"
                )
                show_L = gr.Checkbox(label="Show L-channel", value=False)
                show_m = gr.Checkbox(label="Show metrics", value=True)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")

            examples = gr.Examples(
                examples=[os.path.join("examples", f) for f in os.listdir("examples")] if os.path.exists("examples") else [],
                inputs=img_in,
                examples_per_page=8,
                label=None
            )

        with gr.Column(scale=7):
            with gr.Row():
                orig = gr.Image(label="Original", interactive=False, height=300, show_download_button=True)
                out = gr.Image(label="Result", interactive=False, height=300, show_download_button=True)

            # Pure Gradio metric fields
            with gr.Row():
                mae_box = gr.Number(label="MAE", interactive=False, precision=4)
                psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
                ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)

            gr.Markdown("**Compare**")
            compare = gr.HTML()
            extras = gr.HTML()

    def _go(image, want_metrics, sizing_mode, show_L):
        o, c, mae, psnr, ssim, cmp_html, extra = infer(image, want_metrics, sizing_mode, show_L)
        if not want_metrics:
            mae = psnr = ssim = None
        return o, c, mae, psnr, ssim, cmp_html, extra

    run.click(
        _go,
        inputs=[img_in, show_m, sizing, show_L],
        outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras]
    )

    def _clear():
        return None, None, None, None, None, "", ""
    clr.click(_clear, inputs=None, outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras])

if __name__ == "__main__":
    # No queue, no API panel
    try:
        demo.launch(show_api=False)
    except TypeError:
        demo.launch()
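For reference, the same inference path can be exercised without the Gradio UI. This is a minimal sketch, assuming the checkpoint above has been fetched (so `app.model` is not None) and that `app.py` is importable from the working directory; `colorize_file` is a hypothetical helper, not part of this commit:

# Sketch: run the app's colorization path from a plain Python script (no UI).
# Assumes the checkpoint exists; colorize_file is a hypothetical helper.
import numpy as np
import torch
from PIL import Image

from app import model, to_L, lab_to_rgb, pad_to_multiple

def colorize_file(path, out_path="colorized.png"):
    rgb = np.array(Image.open(path).convert("RGB"))
    proc, (h, w) = pad_to_multiple(rgb, 16)   # pad so H and W are multiples of 16
    L = to_L(proc)                            # 1x1xHxW L-channel tensor in [0, 1]
    with torch.no_grad():
        ab = model(L)                         # predicted ab channels in [-1, 1]
    out = lab_to_rgb(L, ab)[:h, :w]           # back to RGB, crop the padding
    Image.fromarray(out).save(out_path)

colorize_file("examples/Places365_test_00000287.jpg")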
checkpoints/checkpoint_epoch_015_20250808_154437.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:160e9bdb21f474f0da5dd059866966e0fe74b3ae4008307f5e9b1e245b3019c1
size 84569969
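The checkpoint is stored with Git LFS, so a clone without LFS support only contains this small pointer file and app.py will then report the checkpoint as missing. A quick sanity check, as a sketch; the expected byte count comes from the pointer above:

import os

CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"
EXPECTED_BYTES = 84569969  # "size" field from the LFS pointer above

# A bare pointer file is only a few hundred bytes; the real weights are ~85 MB.
actual = os.path.getsize(CKPT) if os.path.exists(CKPT) else 0
print("checkpoint fetched via LFS:", actual == EXPECTED_BYTES)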
examples/Places365_test_00000287.jpg
ADDED
examples/Places365_test_00000314.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import json
from torch.nn.utils import spectral_norm
from torchinfo import summary


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv_block(x)
        pooled = self.pool(features)
        return pooled, features


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.ag = AttentionGate(F_g=in_channels // 2, F_l=skip_channels, F_int=in_channels // 4)

        conv_in_channels = in_channels // 2 + skip_channels

        self.conv_block = nn.Sequential(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        skip = self.ag(x, skip)
        x = torch.cat([x, skip], dim=1)
        x = self.conv_block(x)
        return x


class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(1, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class ViTUNetColorizer(nn.Module):
    def __init__(self, vit_model_name="vit_tiny_patch16_224", freeze_vit_epochs=10):
        super(ViTUNetColorizer, self).__init__()

        self.vit = timm.create_model(vit_model_name, pretrained=True, num_classes=0)
        self.vit_embed_dim = self.vit.embed_dim
        self.vit.head = nn.Identity()

        self.enc1 = EncoderBlock(1, 16)
        self.enc2 = EncoderBlock(16, 32)
        self.enc3 = EncoderBlock(32, 64)
        self.enc4 = EncoderBlock(64, 128)

        self.bottleneck_processor = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.AdaptiveAvgPool2d((14, 14)),
        )

        self.fusion_layer = nn.Sequential(
            nn.Conv2d(128 + self.vit_embed_dim, 128, kernel_size=1),  # type: ignore
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
        )

        self.dec4 = DecoderBlock(128, 64, 64)
        self.dec3 = DecoderBlock(64, 32, 32)
        self.dec2 = DecoderBlock(32, 16, 16)

        self.final_conv = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.GroupNorm(8, 8),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(8, 2, kernel_size=1),
            nn.Tanh(),
        )

        self.freeze_vit_epochs = freeze_vit_epochs
        self.current_epoch = 0

    def extract_vit_features(self, x):
        B = x.shape[0]
        x_3ch = x.repeat(1, 3, 1, 1)

        if x_3ch.shape[-1] != 224:
            x_3ch = F.interpolate(
                x_3ch, size=(224, 224), mode="bicubic", align_corners=False
            )

        x_vit = self.vit.patch_embed(x_3ch)  # type: ignore
        if hasattr(self.vit, 'pos_embed') and self.vit.pos_embed is not None:
            x_vit = x_vit + self.vit.pos_embed[:, 1:, :]  # type: ignore
        x_vit = self.vit.pos_drop(x_vit)  # type: ignore

        for block in self.vit.blocks:  # type: ignore
            x_vit = block(x_vit)

        x_vit = self.vit.norm(x_vit)  # type: ignore
        x_vit = x_vit.transpose(1, 2).reshape(B, self.vit_embed_dim, 14, 14)

        return x_vit

    def forward(self, x):
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        x4, skip4 = self.enc4(x3)

        bottleneck = self.bottleneck_processor(x4)
        vit_features = self.extract_vit_features(x)
        fused = torch.cat([bottleneck, vit_features], dim=1)
        fused = self.fusion_layer(fused)

        fused = F.interpolate(fused, size=x3.shape[2:], mode="bilinear", align_corners=False)

        d4 = self.dec4(fused, skip3)
        d3 = self.dec3(d4, skip2)
        d2 = self.dec2(d3, skip1)

        out = self.final_conv(d2)

        return out

    def set_epoch(self, epoch):
        self.current_epoch = epoch
        requires_grad = epoch >= self.freeze_vit_epochs
        for param in self.vit.parameters():
            param.requires_grad = requires_grad

    def get_param_groups(self, lr_decoder=1e-4, lr_vit=1e-5):
        vit_params = []
        decoder_params = []
        for name, param in self.named_parameters():
            if "vit" in name:
                vit_params.append(param)
            else:
                decoder_params.append(param)
        return [
            {"params": decoder_params, "lr": lr_decoder},
            {"params": vit_params, "lr": lr_vit},
        ]


class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=3, n_filters=64):
        super(PatchDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride=2):
            return [
                spectral_norm(
                    nn.Conv2d(
                        in_filters, out_filters, kernel_size=4, stride=stride, padding=1
                    )
                ),
                nn.LeakyReLU(0.01, inplace=True)
            ]

        self.model = nn.Sequential(
            *discriminator_block(in_channels, n_filters),
            *discriminator_block(n_filters, n_filters * 2),
            *discriminator_block(n_filters * 2, n_filters * 4),
            spectral_norm(nn.Conv2d(n_filters * 4, 1, kernel_size=4, padding=1))
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, L, ab):
        img_input = torch.cat((L, ab), dim=1)
        return self.model(img_input)


if __name__ == "__main__":
    try:
        with open("hyperparameters.json", "r") as f:
            hparams = json.load(f)
        resolution = hparams.get("resolution", 256)
    except FileNotFoundError:
        resolution = 256
        print("Using default resolution: 256x256")

    generator = ViTUNetColorizer()
    generator_input_size = (1, 1, resolution, resolution)
    summary(generator, input_size=generator_input_size)

    discriminator = PatchDiscriminator()
    discriminator_input_size = [(1, 1, resolution, resolution), (1, 2, resolution, resolution)]
    summary(discriminator, input_size=discriminator_input_size)
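The generator exposes set_epoch and get_param_groups so training can keep the ViT frozen for the first freeze_vit_epochs epochs and give it a smaller learning rate afterwards. A minimal sketch of how these hooks might be wired up; the Adam optimizer and the loop below are illustrative assumptions, not taken from this commit:

import torch
from model import ViTUNetColorizer

gen = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224", freeze_vit_epochs=10)
# Two parameter groups: encoder/decoder params at lr_decoder, ViT params at lr_vit.
optimizer = torch.optim.Adam(gen.get_param_groups(lr_decoder=1e-4, lr_vit=1e-5))

for epoch in range(20):
    gen.set_epoch(epoch)  # ViT weights get requires_grad=True only from epoch 10 onward
    # ... per-batch: ab_pred = gen(L); compute losses; optimizer.step() ...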
requirements.txt
ADDED
@@ -0,0 +1,11 @@
gradio
torch
torchvision
torchinfo
numpy
opencv-python-headless
Pillow
scikit-image
kornia
matplotlib
timm