import os
import math
import base64
import tempfile

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image

# kornia is optional; OpenCV is used for color conversion when it is missing.
try:
    import kornia.color as kc
except Exception:
    kc = None

from model import ViTUNetColorizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CKPT = "checkpoints/checkpoint_epoch_017_20250810_193435.pt"

model = None
if os.path.exists(CKPT):
    print(f"Loading model from: {CKPT}")
    model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    state = torch.load(CKPT, map_location=device)
    # Accept either a full training checkpoint or a bare state dict.
    sd = state.get("generator_state_dict", state)
    model.load_state_dict(sd)
    model.eval()
else:
    print(f"Warning: Checkpoint not found at {CKPT}. The app will not be able to colorize images.")


def to_L(rgb_np: np.ndarray):
    """Return the normalized L channel (L / 100) as a 1x1xHxW tensor on `device`."""
    if kc is None:
        # OpenCV fallback: float32 RGB in [0, 1] yields L in [0, 100],
        # so L / 100 matches the range produced by the kornia path below.
        lab = cv2.cvtColor(rgb_np.astype(np.float32) / 255.0, cv2.COLOR_RGB2LAB)
        L = lab[..., 0] / 100.0
        return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    t = torch.from_numpy(rgb_np.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        return kc.rgb_to_lab(t)[:, 0:1] / 100.0


def lab_to_rgb(L, ab):
    """Combine normalized L and ab tensors and return an HxWx3 uint8 RGB array."""
    # Undo the normalization: L back to [0, 100], ab back to [-110, 110].
    lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)
    if kc is None:
        lab = lab[0].permute(1, 2, 0).cpu().numpy()
        lab = np.clip(lab, [0, -128, -128], [100, 127, 127]).astype(np.float32)
        rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
    with torch.no_grad():
        rgb = kc.lab_to_rgb(lab)
    return (torch.clamp(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


def pad_to_multiple(img_np, m=16):
    """Zero-pad the bottom and right edges so both dimensions are multiples of m."""
    h, w = img_np.shape[:2]
    ph, pw = math.ceil(h / m) * m, math.ceil(w / m) * m
    padded = cv2.copyMakeBorder(img_np, 0, ph - h, 0, pw - w, cv2.BORDER_CONSTANT, value=(0, 0, 0))
    return padded, (h, w)


def to_grayscale(image):
    if image is None:
        return None
    return image.convert("L").convert("RGB")


def infer(image: Image.Image):
    """Colorize a PIL image; return (RGB uint8 array, HTML for the result panel)."""
    if image is None:
        return None, None
    if model is None:
        return None, "Checkpoint not found."

    pil = image.convert("RGB")
    rgb = np.array(pil)
    proc, (oh, ow) = pad_to_multiple(rgb, 16)

    L = to_L(proc)
    with torch.no_grad():
        ab = model(L)
    out = lab_to_rgb(L, ab)
    out = out[:oh, :ow]  # crop the padding back off

    # Encode the grayscale input and the colorized output as base64 JPEG data URIs.
    gray_pil = pil.convert("L").convert("RGB")
    _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(gray_pil), cv2.COLOR_RGB2BGR))
    _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
    sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
    compare_html = f"""
"""
    return out, compare_html


def save_for_download(image_array):
    """Saves a NumPy array to a temporary file and returns the path."""
    if image_array is not None:
        pil_img = Image.fromarray(image_array)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            pil_img.save(temp_file.name)
        return temp_file.name
    return None


def make_theme():
    # Fall back to the default Soft theme if this Gradio version lacks these helpers.
    try:
        from gradio.themes.utils import colors, fonts, sizes
        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
        ).set(radius_size=sizes.radius_lg, spacing_size=sizes.spacing_md)
    except Exception:
        return gr.themes.Soft()


THEME = make_theme()

PLACEHOLDER_HTML = """
Result will be shown here
"""

HEAD = """ """

with gr.Blocks(theme=THEME, title="Image Colorizer", head=HEAD) as demo:
    gr.Markdown("# 🎨 Image Colorizer\nWorks best on natural scenes. Learn more about the dataset we trained on [here.](http://places.csail.mit.edu/)")
    result_state = gr.State()

    with gr.Row():
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "webcam", "clipboard"],
            )
            # Show uploads in grayscale so the input matches what the model sees.
            img_in.upload(fn=to_grayscale, inputs=img_in, outputs=img_in)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")
                download_btn = gr.DownloadButton("Download Result", visible=False)
            examples = gr.Examples(
                examples=[os.path.join("examples", f) for f in os.listdir("examples")] if os.path.exists("examples") else [],
                inputs=img_in,
                examples_per_page=8,
                label=None,
            )
        with gr.Column(scale=7):
            out_html = gr.HTML(label="Result", value=PLACEHOLDER_HTML)

    def _go(image):
        out_image, cmp_html = infer(image)
        # Only offer the download button when there is a result to download.
        download_button_update = gr.update(visible=True) if out_image is not None else gr.update(visible=False)
        return out_image, cmp_html, download_button_update

    run.click(
        _go,
        inputs=[img_in],
        outputs=[result_state, out_html, download_btn],
    )

    def _clear():
        return None, None, PLACEHOLDER_HTML, gr.update(visible=False)

    clr.click(
        _clear,
        inputs=None,
        outputs=[img_in, result_state, out_html, download_btn],
    )

    download_btn.click(
        save_for_download,
        inputs=[result_state],
        outputs=[download_btn],
    )

if __name__ == "__main__":
    # Older Gradio releases do not accept show_api.
    try:
        demo.launch(show_api=False)
    except TypeError:
        demo.launch()