# ImageColorizer / app.py
# Author: AUMREDKA — commit 1dc3825 ("Use latest model", verified), 7.66 kB
import base64
import math
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

try:
    import kornia.color as kc
except Exception:
    kc = None

from model import ViTUNetColorizer
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator checkpoint, resolved relative to the current working directory.
CKPT = "checkpoints/checkpoint_epoch_017_20250810_193435.pt"

# None means "no weights available"; `infer` checks for this before running.
model = None
if not os.path.exists(CKPT):
    print(f"Warning: Checkpoint not found at {CKPT}. The app will not be able to colorize images.")
else:
    print(f"Loading model from: {CKPT}")
    model = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    # NOTE: torch.load deserializes via pickle; only point CKPT at trusted files.
    state = torch.load(CKPT, map_location=device)
    # The file may hold the weights directly or wrap them under
    # "generator_state_dict"; accept either layout.
    sd = state.get("generator_state_dict", state)
    model.load_state_dict(sd)
    # Switch to evaluation mode before serving predictions.
    model.eval()
def to_L(rgb_np: np.ndarray):
    """Return the normalized CIELAB lightness channel of an RGB image.

    Args:
        rgb_np: RGB image array of shape (H, W, 3) with values in 0..255
            (in this app, ``infer`` passes a padded ``np.array`` of a PIL RGB
            image, i.e. uint8).

    Returns:
        A float32 tensor of shape (1, 1, H, W) on ``device`` holding L*/100,
        i.e. values in [0, 1]. ``lab_to_rgb`` undoes this with ``L * 100``.
    """
    # Both branches must yield the same quantity (CIELAB L* / 100). The old
    # OpenCV fallback returned 0..255 luma / 100, i.e. values up to 2.55,
    # which does not match the kornia path or what lab_to_rgb expects.
    rgb01 = rgb_np.astype(np.float32) / 255.0
    if kc is None:
        # For float32 input in [0, 1], OpenCV's RGB->Lab returns L* in [0, 100].
        lab = cv2.cvtColor(rgb01, cv2.COLOR_RGB2LAB)
        L = lab[..., 0] / 100.0
        return torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    t = torch.from_numpy(rgb01).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        return kc.rgb_to_lab(t)[:, 0:1] / 100.0
def lab_to_rgb(L, ab):
    """Convert normalized Lab tensors back to an 8-bit RGB image.

    Args:
        L: lightness tensor scaled to [0, 1] (L*/100), as produced by ``to_L``;
            shape (1, 1, H, W).
        ab: chroma tensor; it is clamped to [-1, 1] and scaled by 110 into
            CIELAB a*/b* units. Concatenated with ``L`` along dim 1, so it must
            carry 2 channels — presumably (1, 2, H, W) from the model; confirm.

    Returns:
        uint8 RGB array of shape (H, W, 3), taken from the first batch item.
    """
    # One no_grad scope for the whole conversion: the results are only turned
    # into NumPy, and .numpy() raises on tensors that track gradients (the old
    # fallback built `lab` outside no_grad and could hit that).
    with torch.no_grad():
        lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)
        if kc is None:
            lab_np = lab[0].permute(1, 2, 0).cpu().numpy()
            lab_np = np.clip(lab_np, [0, -128, -128], [100, 127, 127]).astype(np.float32)
            # OpenCV's float Lab->RGB yields RGB in [0, 1].
            rgb = cv2.cvtColor(lab_np, cv2.COLOR_LAB2RGB)
            return (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
        rgb = kc.lab_to_rgb(lab)
        return (torch.clamp(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
def pad_to_multiple(img_np, m=16):
    """Zero-pad an image on the bottom and right so both sides are multiples of ``m``.

    Returns ``(padded, (h, w))`` where ``(h, w)`` is the original height and
    width, so callers can crop a result back to the input size.
    """
    h, w = img_np.shape[:2]
    # Integer ceiling: -(-x // m) == ceil(x / m) for positive ints.
    extra_rows = -(-h // m) * m - h
    extra_cols = -(-w // m) * m - w
    # Pad only the two spatial axes; leave any channel axis untouched.
    pad_width = [(0, extra_rows), (0, extra_cols)] + [(0, 0)] * (img_np.ndim - 2)
    padded = np.pad(img_np, pad_width, mode="constant", constant_values=0)
    return padded, (h, w)
def compute_metrics(pred, gt):
    """Compare a predicted RGB image with a reference image.

    Args:
        pred: predicted image, 0..255-valued array of shape (H, W, 3).
        gt: reference image with the same shape as ``pred``.

    Returns:
        ``(mae, psnr, ssim)`` rounded to 4, 2 and 4 decimals. Both images are
        scaled to [0, 1] first, so MAE is on that scale and PSNR uses
        ``data_range=1.0``. ``ssim`` is None when the image is too small for
        a 3x3 SSIM window.
    """
    p = pred.astype(np.float32) / 255.0
    g = gt.astype(np.float32) / 255.0
    mae = float(np.mean(np.abs(p - g)))
    psnr = float(psnr_metric(g, p, data_range=1.0))

    # structural_similarity raises ValueError when win_size exceeds either
    # spatial side, which crashed the request for images under 7 px. Use the
    # largest odd window (at most 7) that fits; below 3 SSIM is not defined
    # meaningfully (sample covariance needs more than one pixel).
    win = min(7, g.shape[0], g.shape[1])
    if win % 2 == 0:
        win -= 1
    ssim = None
    if win >= 3:
        try:
            ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=win))
        except TypeError:
            # Older scikit-image releases lack channel_axis; use multichannel.
            ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=win))
    return round(mae, 4), round(psnr, 2), (None if ssim is None else round(ssim, 4))
def to_grayscale(image):
    """Return a grayscale copy of ``image`` that is still in 3-channel RGB mode.

    A ``None`` input is passed straight through.
    """
    return None if image is None else image.convert("L").convert("RGB")
def _jpeg_data_uri(rgb):
    """Encode an RGB uint8 array as a base64 JPEG ``data:`` URI."""
    _, buf = cv2.imencode(".jpg", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    return "data:image/jpeg;base64," + base64.b64encode(buf).decode()


def infer(image: Image.Image, want_metrics: bool):
    """Colorize ``image`` and build a before/after comparison widget.

    Args:
        image: input PIL image (converted to RGB here), or None.
        want_metrics: when true, also compute MAE/PSNR/SSIM of the colorized
            result against the (RGB-converted) input image.

    Returns:
        ``(out, mae, psnr, ssim, compare_html)``. ``out`` is the colorized
        uint8 RGB array at the input's size. The metrics are None unless
        requested. ``compare_html`` embeds a grayscale version of the input
        and the result as JPEG data URIs in an ``<img-comparison-slider>``.
        With no image every item is None; with no loaded model only
        ``compare_html`` is set, to an error message.
    """
    if image is None:
        return None, None, None, None, None
    if model is None:
        return None, None, None, None, "<div>Checkpoint not found.</div>"

    pil = image.convert("RGB")
    rgb = np.array(pil)
    # The model gets an input padded to multiples of 16; remember the real size.
    padded, (height, width) = pad_to_multiple(rgb, 16)

    lightness = to_L(padded)
    with torch.no_grad():
        ab = model(lightness)
    # Crop away the padding added above.
    out = lab_to_rgb(lightness, ab)[:height, :width]

    if want_metrics:
        mae, psnr, ssim = compute_metrics(out, rgb)
    else:
        mae = psnr = ssim = None

    so = _jpeg_data_uri(np.array(pil.convert("L").convert("RGB")))
    sc = _jpeg_data_uri(out)
    compare_html = "\n".join(
        [
            "",
            '<div style="margin:auto; border-radius:14px; overflow:hidden;">',
            "<img-comparison-slider>",
            f'<img slot="first" src="{so}" />',
            f'<img slot="second" src="{sc}" />',
            "</img-comparison-slider>",
            "</div>",
            "",
        ]
    )
    return out, mae, psnr, ssim, compare_html
def save_for_download(image_array):
    """Write an image array to a new temporary PNG file and return its path.

    Args:
        image_array: image as a NumPy array (anything ``Image.fromarray``
            accepts), or None.

    Returns:
        Path of the written PNG file, or None when ``image_array`` is None.
        The file is created with ``delete=False`` so it outlives this call;
        nothing here removes it afterwards.
    """
    if image_array is None:
        return None
    pil_img = Image.fromarray(image_array)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        # Write through the already-open handle: reopening temp_file.name
        # while it is still open fails on Windows. The format must be given
        # explicitly because PIL cannot infer it from a file object.
        pil_img.save(temp_file, format="PNG")
        return temp_file.name
def make_theme():
    """Build the Gradio theme for the app.

    Returns a ``gr.themes.Soft`` theme with an indigo primary hue, gray
    neutral hue, the Inter Google font, large corner radius and medium
    spacing. If building it fails for any reason, a warning is printed and a
    default ``Soft`` theme is returned, so the UI still starts.
    """
    try:
        from gradio.themes.utils import colors, fonts, sizes

        # radius_size / spacing_size are theme *constructor* arguments.
        # Theme.set() only accepts individual CSS-variable overrides, so
        # passing them there raised TypeError and the whole customization
        # was silently replaced by the fallback below.
        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
            radius_size=sizes.radius_lg,
            spacing_size=sizes.spacing_md,
        )
    except Exception as exc:  # theming is cosmetic; never block app start-up
        print(f"Warning: custom theme unavailable ({exc!r}); using default Soft theme.")
        return gr.themes.Soft()


THEME = make_theme()
# Shown in the result pane before any image has been colorized.
PLACEHOLDER_HTML = "\n".join(
    [
        "",
        "<div style='display:flex; justify-content:center; align-items:center; height:480px; border: 2px dashed #4B5563; border-radius:12px; color:#4B5563; font-family:sans-serif;'>",
        "<span>Result will be shown here</span>",
        "</div>",
        "",
    ]
)

# Extra <head> content: loads the <img-comparison-slider> web component
# (module script plus default theme stylesheet) from unpkg, major version 8.
_SLIDER_CDN = "https://unpkg.com/img-comparison-slider@8/dist"
HEAD = (
    "\n"
    f'<script type="module" src="{_SLIDER_CDN}/index.js"></script>\n'
    f'<link rel="stylesheet" href="{_SLIDER_CDN}/themes/default.css" />\n'
)
with gr.Blocks(theme=THEME, title="Image Colorizer", head=HEAD) as demo:
    gr.Markdown("# 🎨 Image Colorizer\nWorks best on natural scenes. Learn more about the dataset we trained on [here.](http://places.csail.mit.edu/)")
    # Last colorized image (the NumPy array returned by `infer`); reset by Clear.
    result_state = gr.State()

    with gr.Row():
        # Left column: input, controls, download button and examples.
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "webcam", "clipboard"],
            )
            # Replace an uploaded image with its grayscale version (still RGB mode).
            # NOTE(review): `infer` computes metrics against this input, so for
            # uploads the reference is the grayscale image rather than a color
            # ground truth — confirm this is intended.
            img_in.upload(fn=to_grayscale, inputs=img_in, outputs=img_in)
            show_m = gr.Checkbox(label="Show metrics", value=True)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")
            download_btn = gr.DownloadButton("Download Result", visible=False)
            # sorted() so the example order does not depend on the filesystem.
            example_files = (
                [os.path.join("examples", f) for f in sorted(os.listdir("examples"))]
                if os.path.exists("examples")
                else []
            )
            examples = gr.Examples(
                examples=example_files,
                inputs=img_in,
                examples_per_page=8,
                label=None,
            )
        # Right column: before/after slider and metrics.
        with gr.Column(scale=7):
            out_html = gr.HTML(label="Result", value=PLACEHOLDER_HTML)
            with gr.Row():
                mae_box = gr.Number(label="MAE", interactive=False, precision=4)
                psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
                ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)

    def _go(image, want_metrics):
        """Run `infer` and map its results onto the UI outputs.

        Returns values for
        (result_state, out_html, mae_box, psnr_box, ssim_box, download_btn).
        `infer` already returns None metrics when they were not requested.
        """
        out_image, mae, psnr, ssim, cmp_html = infer(image, want_metrics)
        if cmp_html is None:
            # No input image: keep the placeholder instead of blanking the pane.
            cmp_html = PLACEHOLDER_HTML
        if out_image is None:
            download_update = gr.update(value=None, visible=False)
        else:
            # A DownloadButton serves the file it holds when clicked, so the
            # file is attached here. Assigning it in the button's own click
            # handler meant the first click downloaded nothing.
            download_update = gr.update(value=save_for_download(out_image), visible=True)
        return out_image, cmp_html, mae, psnr, ssim, download_update

    run.click(
        _go,
        inputs=[img_in, show_m],
        outputs=[result_state, out_html, mae_box, psnr_box, ssim_box, download_btn],
    )

    def _clear():
        """Reset the input, stored result, result pane, metrics and download button."""
        return None, None, PLACEHOLDER_HTML, None, None, None, gr.update(value=None, visible=False)

    clr.click(
        _clear,
        inputs=None,
        outputs=[img_in, result_state, out_html, mae_box, psnr_box, ssim_box, download_btn],
    )
def _launch(app):
    """Start ``app``'s server with the API page hidden.

    If ``launch(show_api=False)`` raises TypeError (e.g. the installed Gradio
    does not accept ``show_api``), retry with a plain ``launch()``.
    """
    try:
        app.launch(show_api=False)
    except TypeError:
        app.launch()


if __name__ == "__main__":
    _launch(demo)