AUMREDKA commited on
Commit
0bf782c
·
verified ·
1 Parent(s): cceded8

UI Updates

Browse files

Simplified UI

Files changed (1) hide show
  1. app.py +79 -75
app.py CHANGED
@@ -1,10 +1,13 @@
1
- # app.py — Gradio-native metrics, clean UI, CUDA/CPU only
2
-
3
- import os, math, cv2, base64
4
- import torch, numpy as np, gradio as gr
 
 
 
5
  from PIL import Image
 
6
 
7
- # Optional (fine if missing)
8
  try:
9
  import kornia.color as kc
10
  except Exception:
@@ -12,31 +15,24 @@ except Exception:
12
 
13
  from skimage.metrics import peak_signal_noise_ratio as psnr_metric
14
  from skimage.metrics import structural_similarity as ssim_metric
 
 
15
 
16
- # ---------------- Device & Model (no MPS) ----------------
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
- from model import ViTUNetColorizer
20
  CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"
21
  model = None
22
  if os.path.exists(CKPT):
 
23
  model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
24
  state = torch.load(CKPT, map_location=device)
25
  sd = state.get("generator_state_dict", state)
26
  model.load_state_dict(sd)
27
  model.eval()
28
-
29
- # ---------------- Utils ----------------
30
- def is_grayscale(img: Image.Image) -> bool:
31
- a = np.array(img)
32
- if a.ndim == 2: return True
33
- if a.ndim == 3 and a.shape[2] == 1: return True
34
- if a.ndim == 3 and a.shape[2] == 3:
35
- return np.allclose(a[...,0], a[...,1]) and np.allclose(a[...,1], a[...,2])
36
- return False
37
 
38
  def to_L(rgb_np: np.ndarray):
39
- # ViTUNetColorizer expects L in [0,1]
40
  if kc is None:
41
  gray = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2GRAY).astype(np.float32)
42
  L = gray / 100.0
@@ -71,18 +67,19 @@ def compute_metrics(pred, gt):
71
  ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
72
  return round(mae,4), round(psnr,2), round(ssim,4)
73
 
74
- # ---------------- Inference ----------------
75
- def infer(image: Image.Image, want_metrics: bool, show_L: bool):
76
  if image is None:
77
- return None, None, None, None, None, "", ""
 
 
 
 
 
78
  if model is None:
79
- return None, None, None, None, None, "", "<div>Checkpoint not found in /checkpoints.</div>"
80
 
81
  pil = image.convert("RGB")
82
  rgb = np.array(pil)
83
- w,h = pil.size
84
- was_color = not is_grayscale(pil)
85
-
86
 
87
  proc, (oh, ow) = pad_to_multiple(rgb, 16); back = (ow, oh)
88
 
@@ -93,46 +90,36 @@ def infer(image: Image.Image, want_metrics: bool, show_L: bool):
93
 
94
  out = out[:back[1], :back[0]]
95
 
96
- # Metrics (Gradio-native numbers)
97
  mae = psnr = ssim = None
98
  if want_metrics:
99
  mae, psnr, ssim = compute_metrics(out, np.array(pil))
100
 
101
- # Optional L preview
102
- extra_html = ""
103
- if show_L:
104
- L01 = np.clip(L[0,0].detach().cpu().numpy(),0,1)
105
- L_vis = (L01*255).astype(np.uint8)
106
- L_vis = cv2.cvtColor(L_vis, cv2.COLOR_GRAY2RGB)
107
- _, buf = cv2.imencode(".png", cv2.cvtColor(L_vis, cv2.COLOR_RGB2BGR))
108
- L_b64 = "data:image/png;base64," + base64.b64encode(buf).decode()
109
- extra_html += f"<div><b>L-channel</b><br/><img style='max-height:140px;border-radius:12px' src='{L_b64}'/></div>"
110
-
111
- # Subtle notice only if needed
112
- if was_color:
113
- extra_html += "<div style='opacity:.8;margin-top:8px'>We used a grayscale version of your image for colorization.</div>"
114
-
115
- # Compare slider (HTML only; easy to remove if you want 100% Gradio)
116
- _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR))
117
- _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
118
  so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
119
  sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
120
- compare = f"""
121
- <div style="position:relative;max-width:500px;margin:auto;border-radius:14px;overflow:hidden;box-shadow:0 8px 20px rgba(0,0,0,.2)">
122
- <img src="{so}" style="width:100%;display:block"/>
123
- <div id="cmpTop" style="position:absolute;top:0;left:0;height:100%;width:50%;overflow:hidden">
124
- <img src="{sc}" style="width:100%;display:block"/>
125
- </div>
126
- <input id="cmpRange" type="range" min="0" max="100" value="50"
127
- oninput="document.getElementById('cmpTop').style.width=this.value+'%';"
128
- style="position:absolute;left:0;right:0;bottom:8px;width:60%;margin:auto"/>
129
  </div>
130
  """
131
 
 
132
 
133
- return Image.fromarray(np.array(pil)), Image.fromarray(out), mae, psnr, ssim, compare, extra_html
 
 
 
 
 
 
 
134
 
135
- # ---------------- Theme (fallback-safe) ----------------
136
  def make_theme():
137
  try:
138
  from gradio.themes.utils import colors, fonts, sizes
@@ -146,9 +133,21 @@ def make_theme():
146
 
147
  THEME = make_theme()
148
 
149
- # ---------------- UI ----------------
150
- with gr.Blocks(theme=THEME, title="Image Colorizer") as demo:
 
 
 
 
 
 
 
 
 
 
151
  gr.Markdown("# 🎨 Image Colorizer")
 
 
152
 
153
  with gr.Row():
154
  with gr.Column(scale=5):
@@ -159,12 +158,12 @@ with gr.Blocks(theme=THEME, title="Image Colorizer") as demo:
159
  height=320,
160
  sources=["upload", "webcam", "clipboard"]
161
  )
162
- with gr.Row():
163
- show_L = gr.Checkbox(label="Show L-channel", value=False)
164
  show_m = gr.Checkbox(label="Show metrics", value=True)
165
  with gr.Row():
166
  run = gr.Button("Colorize")
167
  clr = gr.Button("Clear")
 
168
 
169
  examples = gr.Examples(
170
  examples=[os.path.join("examples", f) for f in os.listdir("examples")] if os.path.exists("examples") else [],
@@ -174,39 +173,44 @@ with gr.Blocks(theme=THEME, title="Image Colorizer") as demo:
174
  )
175
 
176
  with gr.Column(scale=7):
177
- with gr.Row():
178
- orig = gr.Image(label="Original", interactive=False, height=300, show_download_button=True)
179
- out = gr.Image(label="Result", interactive=False, height=300, show_download_button=True)
180
-
181
- # Pure Gradio metric fields
182
  with gr.Row():
183
  mae_box = gr.Number(label="MAE", interactive=False, precision=4)
184
  psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
185
  ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)
186
 
187
- gr.Markdown("**Compare**")
188
- compare = gr.HTML()
189
- extras = gr.HTML()
190
-
191
- def _go(image, want_metrics, sizing_mode, show_L):
192
- o, c, mae, psnr, ssim, cmp_html, extra = infer(image, want_metrics, show_L)
193
  if not want_metrics:
194
  mae = psnr = ssim = None
195
- return o, c, mae, psnr, ssim, cmp_html, extra
 
 
 
196
 
197
  run.click(
198
  _go,
199
- inputs=[img_in, show_m, show_L],
200
- outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras]
201
  )
202
 
203
  def _clear():
204
- return None, None, None, None, None, "", ""
205
- clr.click(_clear, inputs=None, outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras])
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  if __name__ == "__main__":
208
- # No queue, no API panel
209
  try:
210
  demo.launch(show_api=False)
211
  except TypeError:
212
- demo.launch()
 
1
+ import os
2
+ import math
3
+ import cv2
4
+ import base64
5
+ import torch
6
+ import numpy as np
7
+ import gradio as gr
8
  from PIL import Image
9
+ import tempfile
10
 
 
11
  try:
12
  import kornia.color as kc
13
  except Exception:
 
15
 
16
  from skimage.metrics import peak_signal_noise_ratio as psnr_metric
17
  from skimage.metrics import structural_similarity as ssim_metric
18
+ from model import ViTUNetColorizer
19
+
20
 
 
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
 
23
  CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"
24
  model = None
25
  if os.path.exists(CKPT):
26
+ print(f"Loading model from: {CKPT}")
27
  model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
28
  state = torch.load(CKPT, map_location=device)
29
  sd = state.get("generator_state_dict", state)
30
  model.load_state_dict(sd)
31
  model.eval()
32
+ else:
33
+ print(f"Warning: Checkpoint not found at {CKPT}. The app will not be able to colorize images.")
 
 
 
 
 
 
 
34
 
35
  def to_L(rgb_np: np.ndarray):
 
36
  if kc is None:
37
  gray = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2GRAY).astype(np.float32)
38
  L = gray / 100.0
 
67
  ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
68
  return round(mae,4), round(psnr,2), round(ssim,4)
69
 
70
+ def to_grayscale(image):
 
71
  if image is None:
72
+ return None
73
+ return image.convert("L").convert("RGB")
74
+
75
+ def infer(image: Image.Image, want_metrics: bool):
76
+ if image is None:
77
+ return None, None, None, None, None
78
  if model is None:
79
+ return None, None, None, None, "<div>Checkpoint not found.</div>"
80
 
81
  pil = image.convert("RGB")
82
  rgb = np.array(pil)
 
 
 
83
 
84
  proc, (oh, ow) = pad_to_multiple(rgb, 16); back = (ow, oh)
85
 
 
90
 
91
  out = out[:back[1], :back[0]]
92
 
 
93
  mae = psnr = ssim = None
94
  if want_metrics:
95
  mae, psnr, ssim = compute_metrics(out, np.array(pil))
96
 
97
+ gray_pil = pil.convert("L").convert("RGB")
98
+ _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(gray_pil), cv2.COLOR_RGB2BGR))
99
+ _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
101
  sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
102
+
103
+ compare_html = f"""
104
+ <div style="margin:auto; border-radius:14px; overflow:hidden;">
105
+ <img-comparison-slider>
106
+ <img slot="first" src="{so}" />
107
+ <img slot="second" src="{sc}" />
108
+ </img-comparison-slider>
 
 
109
  </div>
110
  """
111
 
112
+ return out, mae, psnr, ssim, compare_html
113
 
114
+ def save_for_download(image_array):
115
+ """Saves a NumPy array to a temporary file and returns the path."""
116
+ if image_array is not None:
117
+ pil_img = Image.fromarray(image_array)
118
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
119
+ pil_img.save(temp_file.name)
120
+ return temp_file.name
121
+ return None
122
 
 
123
  def make_theme():
124
  try:
125
  from gradio.themes.utils import colors, fonts, sizes
 
133
 
134
  THEME = make_theme()
135
 
136
+ PLACEHOLDER_HTML = """
137
+ <div style='display:flex; justify-content:center; align-items:center; height:480px; border: 2px dashed #4B5563; border-radius:12px; color:#4B5563; font-family:sans-serif;'>
138
+ <span>Result will be shown here</span>
139
+ </div>
140
+ """
141
+
142
+ HEAD = """
143
+ <script type="module" src="https://unpkg.com/img-comparison-slider@8/dist/index.js"></script>
144
+ <link rel="stylesheet" href="https://unpkg.com/img-comparison-slider@8/dist/themes/default.css" />
145
+ """
146
+
147
+ with gr.Blocks(theme=THEME, title="Image Colorizer", head=HEAD) as demo:
148
  gr.Markdown("# 🎨 Image Colorizer")
149
+
150
+ result_state = gr.State()
151
 
152
  with gr.Row():
153
  with gr.Column(scale=5):
 
158
  height=320,
159
  sources=["upload", "webcam", "clipboard"]
160
  )
161
+ img_in.upload(fn=to_grayscale, inputs=img_in, outputs=img_in)
 
162
  show_m = gr.Checkbox(label="Show metrics", value=True)
163
  with gr.Row():
164
  run = gr.Button("Colorize")
165
  clr = gr.Button("Clear")
166
+ download_btn = gr.DownloadButton("Download Result", visible=False)
167
 
168
  examples = gr.Examples(
169
  examples=[os.path.join("examples", f) for f in os.listdir("examples")] if os.path.exists("examples") else [],
 
173
  )
174
 
175
  with gr.Column(scale=7):
176
+ out_html = gr.HTML(label="Result", value=PLACEHOLDER_HTML)
 
 
 
 
177
  with gr.Row():
178
  mae_box = gr.Number(label="MAE", interactive=False, precision=4)
179
  psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
180
  ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)
181
 
182
+ def _go(image, want_metrics):
183
+ out_image, mae, psnr, ssim, cmp_html = infer(image, want_metrics)
 
 
 
 
184
  if not want_metrics:
185
  mae = psnr = ssim = None
186
+
187
+ download_button_update = gr.update(visible=True) if out_image is not None else gr.update(visible=False)
188
+
189
+ return out_image, cmp_html, mae, psnr, ssim, download_button_update
190
 
191
  run.click(
192
  _go,
193
+ inputs=[img_in, show_m],
194
+ outputs=[result_state, out_html, mae_box, psnr_box, ssim_box, download_btn]
195
  )
196
 
197
  def _clear():
198
+ return None, PLACEHOLDER_HTML, None, None, None, gr.update(visible=False)
199
+
200
+ clr.click(
201
+ _clear,
202
+ inputs=None,
203
+ outputs=[result_state, out_html, mae_box, psnr_box, ssim_box, download_btn]
204
+ )
205
+
206
+ download_btn.click(
207
+ save_for_download,
208
+ inputs=[result_state],
209
+ outputs=[download_btn]
210
+ )
211
 
212
  if __name__ == "__main__":
 
213
  try:
214
  demo.launch(show_api=False)
215
  except TypeError:
216
+ demo.launch()