# app.py — Gradio-native metrics, clean UI, CUDA/CPU only
import base64
import math
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

# Optional (fine if missing) — kornia gives an exact Lab round-trip;
# the cv2 fallback paths below are used when it is absent.
try:
    import kornia.color as kc
except Exception:
    kc = None
# ---------------- Device & Model (no MPS) ----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from model import ViTUNetColorizer

CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"

# `model` stays None when the checkpoint is missing; infer() reports this
# to the user instead of crashing at import time.
model = None
if os.path.exists(CKPT):
    model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    state = torch.load(CKPT, map_location=device)
    # Checkpoint may wrap the weights under "generator_state_dict" or be a
    # bare state_dict; .get() handles both layouts.
    sd = state.get("generator_state_dict", state)
    model.load_state_dict(sd)
    model.eval()
# ---------------- Utils ----------------
def is_grayscale(img: "Image.Image") -> bool:
    """Return True if *img* carries no color information.

    2-D and single-channel arrays are grayscale by definition; a 3-channel
    image counts as grayscale when its three channels are (nearly) equal.
    Anything else (e.g. RGBA) is treated as color.
    """
    a = np.array(img)
    if a.ndim == 2:
        return True
    if a.ndim == 3 and a.shape[2] == 1:
        return True
    if a.ndim == 3 and a.shape[2] == 3:
        return np.allclose(a[..., 0], a[..., 1]) and np.allclose(a[..., 1], a[..., 2])
    return False
def to_L(rgb_np: np.ndarray) -> torch.Tensor:
    """Extract the Lab L channel of an HxWx3 uint8 RGB array, scaled to [0, 1].

    Returns a (1, 1, H, W) float tensor on `device` — the input shape
    ViTUNetColorizer expects.
    """
    if kc is None:
        # BUGFIX: the old fallback divided an RGB2GRAY result (0..255) by
        # 100, yielding "L" in [0, 2.55]. Convert through Lab instead so
        # L is in [0, 100] before scaling, matching the kornia path.
        lab = cv2.cvtColor(rgb_np.astype(np.float32) / 255.0, cv2.COLOR_RGB2LAB)
        L = lab[..., 0] / 100.0
        return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    t = torch.from_numpy(rgb_np.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        return kc.rgb_to_lab(t)[:, 0:1] / 100.0
def lab_to_rgb(L: torch.Tensor, ab: torch.Tensor) -> np.ndarray:
    """Combine L in [0, 1] and ab in [-1, 1] into an HxWx3 uint8 RGB image.

    L is (1, 1, H, W), ab is (1, 2, H, W); ab is clamped and rescaled to
    roughly [-110, 110] before conversion.
    """
    if kc is None:
        # OpenCV float Lab expects L in [0, 100] and a/b in about [-127, 127].
        lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)[0].permute(1, 2, 0).cpu().numpy()
        lab = np.clip(lab, [0, -128, -128], [100, 127, 127]).astype(np.float32)
        rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
    lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)
    with torch.no_grad():
        rgb = kc.lab_to_rgb(lab)
    return (torch.clamp(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
def pad_to_multiple(img_np: np.ndarray, m: int = 16):
    """Zero-pad *img_np* on the bottom/right so H and W are multiples of *m*.

    Returns ``(padded, (orig_h, orig_w))`` so callers can crop the result
    back to the original size. Works for both HxW and HxWxC arrays.
    """
    h, w = img_np.shape[:2]
    ph, pw = math.ceil(h / m) * m, math.ceil(w / m) * m
    # np.pad with constant zeros is equivalent to cv2.copyMakeBorder
    # (BORDER_CONSTANT, value 0) but avoids the OpenCV dependency here.
    pad_spec = [(0, ph - h), (0, pw - w)] + [(0, 0)] * (img_np.ndim - 2)
    return np.pad(img_np, pad_spec, mode="constant", constant_values=0), (h, w)
def compute_metrics(pred: np.ndarray, gt: np.ndarray):
    """Return (MAE, PSNR, SSIM) between two same-shaped uint8 RGB images.

    Values are rounded for display: MAE/SSIM to 4 decimals, PSNR (dB) to 2.
    """
    p = pred.astype(np.float32) / 255.0
    g = gt.astype(np.float32) / 255.0
    mae = float(np.mean(np.abs(p - g)))
    psnr = float(psnr_metric(g, p, data_range=1.0))
    try:
        # skimage >= 0.19 takes channel_axis; older releases used multichannel.
        ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=7))
    except TypeError:
        ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
    return round(mae, 4), round(psnr, 2), round(ssim, 4)
# ---------------- Inference ----------------
def infer(image: "Image.Image", want_metrics: bool, show_L: bool):
    """Colorize *image* with the global model.

    Returns a 7-tuple matching the UI outputs:
    (original PIL, colorized PIL, mae, psnr, ssim, compare_html, extra_html).
    Metric slots are None when metrics are disabled or on early exit.
    """
    if image is None:
        return None, None, None, None, None, "", ""
    if model is None:
        return None, None, None, None, None, "", "<div>Checkpoint not found in /checkpoints.</div>"
    pil = image.convert("RGB")
    rgb = np.array(pil)
    was_color = not is_grayscale(pil)
    # Pad so H/W are multiples of 16 (network stride); remember original size.
    proc, (oh, ow) = pad_to_multiple(rgb, 16)
    back = (ow, oh)
    L = to_L(proc)
    with torch.no_grad():
        ab = model(L)
    out = lab_to_rgb(L, ab)
    # Crop the padding back off: back = (width, height).
    out = out[:back[1], :back[0]]
    # Metrics (Gradio-native numbers)
    mae = psnr = ssim = None
    if want_metrics:
        mae, psnr, ssim = compute_metrics(out, np.array(pil))
    # Optional L preview (rendered as inline base64 HTML)
    extra_html = ""
    if show_L:
        L01 = np.clip(L[0, 0].detach().cpu().numpy(), 0, 1)
        L_vis = (L01 * 255).astype(np.uint8)
        L_vis = cv2.cvtColor(L_vis, cv2.COLOR_GRAY2RGB)
        _, buf = cv2.imencode(".png", cv2.cvtColor(L_vis, cv2.COLOR_RGB2BGR))
        L_b64 = "data:image/png;base64," + base64.b64encode(buf).decode()
        extra_html += f"<div><b>L-channel</b><br/><img style='max-height:140px;border-radius:12px' src='{L_b64}'/></div>"
    # Subtle notice only if needed
    if was_color:
        extra_html += "<div style='opacity:.8;margin-top:8px'>We used a grayscale version of your image for colorization.</div>"
    # Compare slider (HTML only; easy to remove if you want 100% Gradio)
    _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR))
    _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
    sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
    compare = f"""
    <div style="position:relative;max-width:500px;margin:auto;border-radius:14px;overflow:hidden;box-shadow:0 8px 20px rgba(0,0,0,.2)">
      <img src="{so}" style="width:100%;display:block"/>
      <div id="cmpTop" style="position:absolute;top:0;left:0;height:100%;width:50%;overflow:hidden">
        <img src="{sc}" style="width:100%;display:block"/>
      </div>
      <input id="cmpRange" type="range" min="0" max="100" value="50"
             oninput="document.getElementById('cmpTop').style.width=this.value+'%';"
             style="position:absolute;left:0;right:0;bottom:8px;width:60%;margin:auto"/>
    </div>
    """
    return Image.fromarray(np.array(pil)), Image.fromarray(out), mae, psnr, ssim, compare, extra_html
# ---------------- Theme (fallback-safe) ----------------
def make_theme():
    """Build a tweaked Soft theme; fall back to plain Soft on any gradio API mismatch."""
    try:
        from gradio.themes.utils import colors, fonts, sizes
        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
        ).set(radius_size=sizes.radius_lg, spacing_size=sizes.spacing_md)
    except Exception:
        # Older/newer gradio may lack these utils or .set() kwargs.
        return gr.themes.Soft()


THEME = make_theme()
# ---------------- UI ----------------
with gr.Blocks(theme=THEME, title="Image Colorizer") as demo:
    gr.Markdown("# 🎨 Image Colorizer")
    with gr.Row():
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "webcam", "clipboard"]
            )
            with gr.Row():
                show_L = gr.Checkbox(label="Show L-channel", value=False)
                show_m = gr.Checkbox(label="Show metrics", value=True)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")
            examples = gr.Examples(
                examples=[os.path.join("examples", f) for f in os.listdir("examples")] if os.path.exists("examples") else [],
                inputs=img_in,
                examples_per_page=8,
                label=None
            )
        with gr.Column(scale=7):
            with gr.Row():
                orig = gr.Image(label="Original", interactive=False, height=300, show_download_button=True)
                out = gr.Image(label="Result", interactive=False, height=300, show_download_button=True)
            # Pure Gradio metric fields
            with gr.Row():
                mae_box = gr.Number(label="MAE", interactive=False, precision=4)
                psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
                ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)
            gr.Markdown("**Compare**")
            compare = gr.HTML()
            extras = gr.HTML()

    # BUGFIX: the handler previously declared a fourth `sizing_mode`
    # parameter while run.click wired only three inputs, so every click
    # would raise TypeError. Signature now matches the wired inputs.
    def _go(image, want_metrics, show_L):
        o, c, mae, psnr, ssim, cmp_html, extra = infer(image, want_metrics, show_L)
        if not want_metrics:
            mae = psnr = ssim = None
        return o, c, mae, psnr, ssim, cmp_html, extra

    run.click(
        _go,
        inputs=[img_in, show_m, show_L],
        outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras]
    )

    def _clear():
        # One None/empty value per wired output component.
        return None, None, None, None, None, "", ""

    clr.click(_clear, inputs=None, outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras])
if __name__ == "__main__":
    # No queue, no API panel; older gradio versions lack show_api.
    try:
        demo.launch(show_api=False)
    except TypeError:
        demo.launch()