# DIY-SC / app.py
import os
import spaces
import gradio as gr
import numpy as np
import gc
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from matplotlib import cm
from model_utils.extractor_dino import ViTExtractor
from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork
def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
    """Resize `img` so its long side equals `target_res`, then pad it to a square.

    With edge=False the image is centered on a zero (black) canvas; with
    edge=True the borders are edge-padded instead.
    """
    filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
    original_width, original_height = img.size
    original_channels = len(img.getbands())
    if not edge:
        canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
        if original_channels == 1:
            canvas = np.zeros([target_res, target_res], dtype=np.uint8)
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[(width - height) // 2: (width + height) // 2] = img
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[:, (height - width) // 2: (height + width) // 2] = img
    else:
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            top_pad = (target_res - height) // 2
            bottom_pad = target_res - height - top_pad
            img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            left_pad = (target_res - width) // 2
            right_pad = target_res - width - left_pad
            img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
        canvas = img
    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas
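
# Illustrative sketch (not called by the app): for a hypothetical 640x480 RGB
# input and target_res=630, `resize` scales it to 630x472 and centers it on a
# black 630x630 canvas; with edge=True the borders are edge-padded instead.
def _resize_example(sample: Image.Image):
    square = resize(sample, target_res=630)             # zero-padded square
    padded = resize(sample, target_res=630, edge=True)  # edge-padded variant
    return square.size, padded.size                     # both (630, 630)
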
# ─── Feature extraction ──────────────────────────────────────────
@spaces.GPU(duration=20)  # request a GPU for up to 20 s on ZeroGPU Spaces
def get_processed_features_dino(num_patches, img, use_dummy):
    """Extract DINOv2 patch tokens and project them with the selected aggregation network."""
    with torch.no_grad():
        batch = extractor_vit.preprocess_pil(img)
        features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
            .permute(0, 1, 3, 2) \
            .reshape(1, -1, num_patches, num_patches)
        # "DINOv2" passes the raw features through the dummy network;
        # otherwise the learned DIY-SC aggregation network projects them.
        if use_dummy == "DINOv2":
            desc = aggre_net_dummy(features_dino)
        else:
            desc = aggre_net(features_dino)
        # L2-normalize along the channel dimension
        norms = torch.linalg.norm(desc, dim=1, keepdim=True)
        desc = desc / (norms + 1e-8)
        desc = desc.cpu()
    gc.collect()
    torch.cuda.empty_cache()
    return desc  # shape [1, C, num_patches, num_patches]
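
# Minimal usage sketch (never called by the app): with num_patches=45 and
# projection_dim=768, the returned descriptor should be a [1, 768, 45, 45]
# tensor whose channel vectors are approximately unit-norm at every patch.
def _descriptor_example(img: Image.Image):
    desc = get_processed_features_dino(45, img=img, use_dummy="DIY-SC")
    assert tuple(desc.shape[-2:]) == (45, 45)
    return desc
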
# ─── Similarity computation ───────────────────────────────────────
def get_sim(
    coord: tuple[int, int],
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    img_size: int = 420
) -> np.ndarray:
    """
    Upsamples the descriptors, then computes cosine similarity between the
    feature at `coord` (row, col) in the source and every spatial location in
    the target. Note: the module-level `upsampler` is fixed to `target_res`,
    so `img_size` is informational only.
    """
    y, x = coord  # (row, col)
    # Upsample both feature maps to [1, C, target_res, target_res]
    src_ft = upsampler(feat1)
    trg_ft = upsampler(feat2)
    # Extract the C-dim vector at the clicked location
    C = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].view(1, C, 1, 1)  # [1, C, 1, 1]
    # Cosine similarity along the channel dimension
    cos = nn.CosineSimilarity(dim=1)
    cos_map = cos(src_vec, trg_ft)[0]  # [target_res, target_res]
    return cos_map.cpu().numpy()
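
# Hedged usage sketch (never called by the app): finding the corresponding
# point in the target for a clicked source pixel reduces to an argmax over
# the similarity map returned above.
def _best_match_example(feat1: torch.Tensor, feat2: torch.Tensor, row: int, col: int):
    sim = get_sim((row, col), feat1, feat2, img_size=target_res)
    best_row, best_col = np.unravel_index(sim.argmax(), sim.shape)
    return (int(best_row), int(best_col)), float(sim[best_row, best_col])
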
# ─── Drawing helper ───────────────────────────────────────────────
def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255, 0, 0)) -> np.ndarray:
    """Draw a filled circle of diameter `size` centered at pixel (x, y)."""
    pil = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(pil)
    r = size // 2
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    return np.array(pil)
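
# E.g., draw_point(arr, x=100, y=50, size=10) returns the image with a 10 px
# red dot centered at pixel column 100, row 50.
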
# ─── Feature‐updating callback ───────────────────────────────────
def update_features(
    img: Image.Image,
    num_patches,
    use_dummy
):
    """
    Given a PIL image, returns:
      1) the resized PIL image (so it can be displayed),
      2) its descriptor tensor (stored in a gr.State),
      3) a pristine copy of the resized image for later re-drawing.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if img is None:
        return None, None, None
    img = resize(img, target_res=target_res, resize=True, to_pil=True)
    feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
    return img, feat.cpu(), Image.fromarray(np.array(img))
# ─── Click handler ───────────────────────────────────────────────
def on_select(
    source_pil: Image.Image,
    target_pil: Image.Image,
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    alpha: float,
    scatter_size: int,
    or_tgt_img: Image.Image,
    or_src_img: Image.Image,
    sel: gr.SelectData
):
    gc.collect()
    torch.cuda.empty_cache()
    # Work on pristine copies so markers from previous clicks don't accumulate
    src_arr = np.array(or_src_img)
    tgt_arr = np.array(or_tgt_img)
    # Gradio reports the click as (x, y) = (col, row)
    x, y = sel.index
    src_marked = draw_point(src_arr, x, y, scatter_size)
    # Compute the similarity map for the clicked source location
    sim_map = get_sim((y, x), feat1, feat2, img_size=target_res)
    mn, mx = sim_map.min(), sim_map.max()
    sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)
    # Build an RGBA heatmap (H×W×4) from the normalized similarities
    heat = cm.viridis(sim_norm)
    heat[..., 3] = sim_norm * alpha  # per-pixel alpha channel
    # Composite the heatmap over a fresh copy of the target
    tgt_f = tgt_arr.astype(np.float32) / 255.0
    comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
    overlay = (comp * 255).astype(np.uint8)
    # Draw a red dot at the best match
    my, mx_ = np.unravel_index(sim_map.argmax(), sim_map.shape)
    overlay_marked = draw_point(overlay, mx_, my, scatter_size)
    return src_marked, overlay_marked
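
# Note on the compositing above: comp = heat_rgb * a + target * (1 - a) is
# standard "over" alpha blending with per-pixel alpha a = sim_norm * alpha, so
# confident matches glow in viridis colors while weak regions stay photographic.
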
def reload_img(
    or_src_img: Image.Image,
    or_tgt_img: Image.Image,
):
    """Return the stored, unannotated source/target images (currently not wired to the UI)."""
    return or_src_img, or_tgt_img
# ─── Configuration ───────────────────────────────────────────────
num_patches = 45
target_res = num_patches * 14  # 630 px: 45 patches × 14 px per DINOv2 ViT-B/14 patch
ckpt_file = "./ckpts/dino_spair_0300.pth"
# ─── Model setup ─────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
aggre_net_dummy = DummyAggregationNetwork()
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
aggre_net = aggre_net.eval()
extractor_vit.model.eval()
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
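# Bilinear upsampling maps the coarse [1, C, 45, 45] descriptor grid to the
# display resolution [1, C, 630, 630], so every pixel clicked in the UI
# indexes a smoothly interpolated feature vector in `get_sim`.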
gc.collect()
torch.cuda.empty_cache()
# ─── Build Gradio UI ──────────────────────────────────────────────
with gr.Blocks() as demo:
    # Hidden states to hold features and pristine images
    feat1_state = gr.State()
    feat2_state = gr.State()
    or_tgt_img = gr.State()
    or_src_img = gr.State()

    # Introduction text box
    intro_text = gr.Markdown("""
    ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
    [Project Page](https://genintel.github.io/DIY-SC) | [GitHub Repository](https://github.com/odunkel/DIY-SC)

    Welcome to the DIY-SC demo!
    Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
    You can choose between the DIY-SC (DINOv2) and the DINOv2 feature extractor.
    """)

    # Image upload / display components
    with gr.Row():
        src = gr.Image(interactive=True, type="pil", label="Source Image")
        tgt = gr.Image(interactive=True, type="pil", label="Target Image")

    # Controls
    alpha = gr.State(0.7)
    scatter = gr.State(10)
    use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")

    # Recompute features whenever an image is uploaded or the extractor changes
    src.input(
        fn=update_features,
        inputs=[src, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    tgt.input(
        fn=update_features,
        inputs=[tgt, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    use_dummy.change(
        fn=update_features,
        inputs=[or_src_img, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    use_dummy.change(
        fn=update_features,
        inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    # Clicking the source image triggers the correspondence visualization
    src.select(
        fn=on_select,
        inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
        outputs=[src, tgt]
    )
if __name__ == "__main__":
    demo.launch()