# app.py — Gradio-native metrics, clean UI, CUDA/CPU only
import os
import math
import base64

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image

# Optional (fine if missing): kornia gives exact Lab conversion on the GPU;
# without it we fall back to OpenCV on the CPU.
try:
    import kornia.color as kc
except Exception:
    kc = None

from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

# ---------------- Device & Model (no MPS) ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from model import ViTUNetColorizer

# CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"
CKPT = "checkpoints/checkpoint_epoch_022_20250808_190318.pt"

model = None
if os.path.exists(CKPT):
    model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    state = torch.load(CKPT, map_location=device)
    # Checkpoint may be a raw state_dict or a dict wrapping one.
    sd = state.get("generator_state_dict", state)
    model.load_state_dict(sd)
    model.eval()


# ---------------- Utils ----------------
def is_grayscale(img: Image.Image) -> bool:
    """Return True when the image carries no chroma (all channels equal)."""
    a = np.array(img)
    if a.ndim == 2:
        return True
    if a.ndim == 3 and a.shape[2] == 1:
        return True
    if a.ndim == 3 and a.shape[2] == 3:
        return np.allclose(a[..., 0], a[..., 1]) and np.allclose(a[..., 1], a[..., 2])
    return False


def to_L(rgb_np: np.ndarray) -> torch.Tensor:
    """Convert an HxWx3 uint8 RGB array to an L-channel tensor (1,1,H,W) in [0,1].

    ViTUNetColorizer expects L in [0, 1].
    """
    if kc is None:
        gray = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2GRAY).astype(np.float32)
        # BUG FIX: gray is in 0..255; the original divided by 100.0, which
        # yields values up to 2.55 instead of the expected [0, 1] range.
        L = gray / 255.0
        return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    t = (
        torch.from_numpy(rgb_np.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        # Lab L channel is in [0, 100]; normalize to [0, 1].
        return kc.rgb_to_lab(t)[:, 0:1] / 100.0


def lab_to_rgb(L: torch.Tensor, ab: torch.Tensor) -> np.ndarray:
    """Recombine L in [0,1] and ab in [-1,1] into an HxWx3 uint8 RGB image."""
    if kc is None:
        lab = (
            torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)[0]
            .permute(1, 2, 0)
            .cpu()
            .numpy()
        )
        lab = np.clip(lab, [0, -128, -128], [100, 127, 127]).astype(np.float32)
        rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
    lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)
    with torch.no_grad():
        rgb = kc.lab_to_rgb(lab)
    return (torch.clamp(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


def pad_to_multiple(img_np: np.ndarray, m: int = 16):
    """Zero-pad bottom/right so both dimensions are multiples of *m*.

    Returns (padded_image, (orig_h, orig_w)) so callers can crop back.
    """
    h, w = img_np.shape[:2]
    ph, pw = math.ceil(h / m) * m, math.ceil(w / m) * m
    padded = cv2.copyMakeBorder(
        img_np, 0, ph - h, 0, pw - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    return padded, (h, w)


def compute_metrics(pred: np.ndarray, gt: np.ndarray):
    """Return (MAE, PSNR dB, SSIM) between two uint8 RGB images, rounded."""
    p = pred.astype(np.float32) / 255.0
    g = gt.astype(np.float32) / 255.0
    mae = float(np.mean(np.abs(p - g)))
    psnr = float(psnr_metric(g, p, data_range=1.0))
    try:
        ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=7))
    except TypeError:
        # Older scikit-image releases use `multichannel` instead of `channel_axis`.
        ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
    return round(mae, 4), round(psnr, 2), round(ssim, 4)


# ---------------- Inference ----------------
def infer(image: Image.Image, want_metrics: bool, sizing_mode: str, show_L: bool):
    """Colorize *image* and return
    (original, colorized, mae, psnr, ssim, compare_html, extra_html)."""
    if image is None:
        return None, None, None, None, None, "", ""
    if model is None:
        return None, None, None, None, None, "", "Checkpoint not found in /checkpoints."

    pil = image.convert("RGB")
    rgb = np.array(pil)
    w, h = pil.size
    was_color = not is_grayscale(pil)

    if sizing_mode == "Pad to keep size":
        proc, (oh, ow) = pad_to_multiple(rgb, 16)
        back = (ow, oh)
    else:
        proc = cv2.resize(rgb, (256, 256), interpolation=cv2.INTER_CUBIC)
        back = (w, h)

    L = to_L(proc)
    with torch.no_grad():
        ab = model(L)
    out = lab_to_rgb(L, ab)

    if sizing_mode == "Pad to keep size":
        out = out[:back[1], :back[0]]  # crop the padding off
    else:
        out = cv2.resize(out, back, interpolation=cv2.INTER_CUBIC)

    # Metrics (Gradio-native numbers)
    mae = psnr = ssim = None
    if want_metrics:
        mae, psnr, ssim = compute_metrics(out, np.array(pil))

    # Optional L-channel preview
    extra_html = ""
    if show_L:
        L01 = np.clip(L[0, 0].detach().cpu().numpy(), 0, 1)
        L_vis = (L01 * 255).astype(np.uint8)
        L_vis = cv2.cvtColor(L_vis, cv2.COLOR_GRAY2RGB)
        _, buf = cv2.imencode(".png", cv2.cvtColor(L_vis, cv2.COLOR_RGB2BGR))
        L_b64 = "data:image/png;base64," + base64.b64encode(buf).decode()
        # NOTE(review): the original markup was garbled in the source; this is
        # a minimal equivalent that preserves the visible "L-channel" caption.
        extra_html += (
            f"<div><img src='{L_b64}' style='max-width:100%'/>"
            f"<div>L-channel</div></div>"
        )

    # Subtle notice only if needed
    if was_color:
        extra_html += (
            "<div>We used a grayscale version of your image for colorization.</div>"
        )

    # Compare view (HTML only; easy to remove if you want 100% Gradio)
    _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR))
    _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
    sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
    # NOTE(review): the original compare-slider template was lost to formatting
    # damage; a simple side-by-side layout stands in until it is restored.
    compare = (
        f"<div style='display:flex;gap:8px'>"
        f"<img src='{so}' style='width:50%'/>"
        f"<img src='{sc}' style='width:50%'/>"
        f"</div>"
    )

    return (
        Image.fromarray(np.array(pil)),
        Image.fromarray(out),
        mae,
        psnr,
        ssim,
        compare,
        extra_html,
    )


# ---------------- Theme (fallback-safe) ----------------
def make_theme():
    """Build the Soft theme with custom hues/fonts; fall back to defaults
    when the theme utils are unavailable in this Gradio version."""
    try:
        from gradio.themes.utils import colors, fonts, sizes

        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
        ).set(radius_size=sizes.radius_lg, spacing_size=sizes.spacing_md)
    except Exception:
        return gr.themes.Soft()


THEME = make_theme()

# ---------------- UI ----------------
with gr.Blocks(theme=THEME, title="Neural Colorizer") as demo:
    gr.Markdown("# 🎨 Neural Colorizer")
    with gr.Row():
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload grayscale or color image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "clipboard"],
            )
            with gr.Row():
                sizing = gr.Radio(
                    ["Resize to 256", "Pad to keep size"],
                    value="Resize to 256",
                    label="Sizing",
                )
                show_L = gr.Checkbox(label="Show L-channel", value=False)
                show_m = gr.Checkbox(label="Show metrics", value=True)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")
            examples = gr.Examples(
                examples=(
                    [os.path.join("examples", f) for f in os.listdir("examples")]
                    if os.path.exists("examples")
                    else []
                ),
                inputs=img_in,
                examples_per_page=8,
                label=None,
            )
        with gr.Column(scale=7):
            with gr.Row():
                orig = gr.Image(
                    label="Original",
                    interactive=False,
                    height=300,
                    show_download_button=True,
                )
                out = gr.Image(
                    label="Result",
                    interactive=False,
                    height=300,
                    show_download_button=True,
                )
            # Pure Gradio metric fields
            with gr.Row():
                mae_box = gr.Number(label="MAE", interactive=False, precision=4)
                psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
                ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)
            gr.Markdown("**Compare**")
            compare = gr.HTML()
            extras = gr.HTML()

    def _go(image, want_metrics, sizing_mode, show_L):
        """Run inference and blank the metric boxes when metrics are off."""
        o, c, mae, psnr, ssim, cmp_html, extra = infer(
            image, want_metrics, sizing_mode, show_L
        )
        if not want_metrics:
            mae = psnr = ssim = None
        return o, c, mae, psnr, ssim, cmp_html, extra

    run.click(
        _go,
        inputs=[img_in, show_m, sizing, show_L],
        outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras],
    )

    def _clear():
        """Reset every output widget."""
        return None, None, None, None, None, "", ""

    clr.click(
        _clear,
        inputs=None,
        outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras],
    )

if __name__ == "__main__":
    # No queue, no API panel; older Gradio versions lack `show_api`.
    try:
        demo.launch(show_api=False)
    except TypeError:
        demo.launch()