import os
import spaces  # ZeroGPU support on Hugging Face Spaces
import gradio as gr
import numpy as np
import gc
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from matplotlib import cm
from model_utils.extractor_dino import ViTExtractor
from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork


def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
    """Letterbox `img` onto a square `target_res` canvas, preserving aspect ratio.

    With edge=False the border is zero-padded; with edge=True it is
    edge-replicated. Returns a PIL image if to_pil=True, else a numpy array.
    """
    filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
    original_width, original_height = img.size
    original_channels = len(img.getbands())
    if not edge:
        canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
        if original_channels == 1:
            canvas = np.zeros([target_res, target_res], dtype=np.uint8)
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[(width - height) // 2: (width + height) // 2] = img
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[:, (height - width) // 2: (height + width) // 2] = img
    else:
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            top_pad = (target_res - height) // 2
            bottom_pad = target_res - height - top_pad
            img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            left_pad = (target_res - width) // 2
            right_pad = target_res - width - left_pad
            img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
        canvas = img
    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas
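
# Usage sketch (illustrative only, not executed by the app; 'example.jpg' is
# a placeholder path):
#   pil = Image.open('example.jpg').convert('RGB')
#   square = resize(pil, target_res=630, resize=True, to_pil=True)  # zero-padded letterbox
#   padded = resize(pil, target_res=630, edge=True)                 # edge-replicated padding
#   square.size  # -> (630, 630)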


# ─── Feature extraction ──────────────────────────────────────────
def get_processed_features_dino(num_patches, img, use_dummy):
    with torch.no_grad():
        batch = extractor_vit.preprocess_pil(img)
        features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
            .permute(0, 1, 3, 2) \
            .reshape(1, -1, num_patches, num_patches)
        if use_dummy == "DINOv2":
            desc = aggre_net_dummy(features_dino)  # raw DINOv2 features (identity projection)
        else:
            desc = aggre_net(features_dino)        # DIY-SC aggregation network
        # L2-normalize along the channel dimension
        norms = torch.linalg.norm(desc, dim=1, keepdim=True)
        desc = desc / (norms + 1e-8)
        desc = desc.cpu()
        # del batch, features_dino
    gc.collect()
    torch.cuda.empty_cache()
    return desc  # shape [1, C, num_patches, num_patches]
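
# Usage sketch (assumes the models under "Model setup" below are loaded;
# 'cat.jpg' is a placeholder path):
#   pil = resize(Image.open('cat.jpg').convert('RGB'), target_res=target_res)
#   desc = get_processed_features_dino(num_patches, img=pil, use_dummy="DIY-SC")
#   desc.shape  # -> torch.Size([1, 768, 45, 45]) with the configuration below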


# ─── Similarity computation ──────────────────────────────────────
def get_sim(
    coord: tuple[int, int],
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    img_size: int = 420
) -> np.ndarray:
    """
    Upsamples the DINO features, then computes cosine similarity between the
    feature at `coord` (row, col) in the source and every spatial location in
    the target. Note: the upsampling size is fixed by the global `upsampler`
    (target_res), so callers should pass img_size=target_res.
    """
    y, x = coord  # (row, col)
    # Upsample both feature maps to [1, C, target_res, target_res]
    src_ft = upsampler(feat1)
    trg_ft = upsampler(feat2)
    # Extract the C-dim vector at the clicked location
    C = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].view(1, C, 1, 1)  # [1, C, 1, 1]
    # Cosine similarity along the channel dim
    cos = nn.CosineSimilarity(dim=1)
    cos_map = cos(src_vec, trg_ft)[0]  # [target_res, target_res]
    return cos_map.cpu().numpy()
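
# Usage sketch (toy tensors; relies on the global `upsampler`, so pass
# img_size=target_res as on_select does):
#   f1 = torch.randn(1, 768, num_patches, num_patches)
#   f2 = torch.randn(1, 768, num_patches, num_patches)
#   sim = get_sim((100, 200), f1, f2, img_size=target_res)  # coord is (row, col)
#   sim.shape  # -> (630, 630), cosine values in [-1, 1]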


# ─── Drawing helper ──────────────────────────────────────────────
def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255, 0, 0)) -> np.ndarray:
    pil = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(pil)
    r = size // 2
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    return np.array(pil)


# ─── Feature-updating callback ───────────────────────────────────
def update_features(
    img: Image.Image,
    num_patches,
    use_dummy
):
    """
    Given a PIL image, returns:
      1) the resized PIL image (so it can be displayed)
      2) its DINO descriptor tensor, stored in a gr.State
      3) an unmarked copy of the resized image, kept for redrawing
    """
    gc.collect()
    torch.cuda.empty_cache()
    if img is None:
        return None, None, None
    img = resize(img, target_res=target_res, resize=True, to_pil=True)
    feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
    return img, feat.cpu(), Image.fromarray(np.array(img))


# ─── Click handler ───────────────────────────────────────────────
def on_select(
    source_pil: Image.Image,
    target_pil: Image.Image,
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    alpha: float,
    scatter_size: int,
    or_tgt_img: Image.Image,
    or_src_img: Image.Image,
    sel: gr.SelectData
):
    gc.collect()
    torch.cuda.empty_cache()
    # Convert the fresh, unmarked images to numpy arrays
    src_arr = np.array(or_src_img)
    tgt_arr = np.array(or_tgt_img)
    # Gradio reports the click as (x, y) = (col, row)
    x, y = sel.index
    src_marked = draw_point(src_arr, x, y, scatter_size)
    # Compute similarity map; get_sim expects (row, col)
    sim_map = get_sim((y, x), feat1, feat2, img_size=target_res)
    mn, mx = sim_map.min(), sim_map.max()
    sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)
    # Build RGBA heatmap
    heat = cm.viridis(sim_norm)      # H×W×4
    heat[..., 3] = sim_norm * alpha  # alpha channel
    # Composite over fresh target
    tgt_f = tgt_arr.astype(np.float32) / 255.0
    comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
    overlay = (comp * 255).astype(np.uint8)
    # Draw a red dot at the best match (argmax returns row, col)
    best_y, best_x = np.unravel_index(sim_map.argmax(), sim_map.shape)
    overlay_marked = draw_point(overlay, best_x, best_y, scatter_size)
    return src_marked, overlay_marked
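
# Compositing arithmetic (worked example): with alpha = 0.7, a pixel whose
# normalized similarity is 0.5 gets overlay weight 0.5 * 0.7 = 0.35, so the
# output there is 0.35 * viridis_RGB + 0.65 * target_RGB. High-similarity
# regions thus appear both brighter and more opaque.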


def reload_img(
    or_src_img: Image.Image,
    or_tgt_img: Image.Image,
):
    return or_src_img, or_tgt_img


# ─── Configuration ───────────────────────────────────────────────
num_patches = 45                           # 45 x 45 token grid
target_res = num_patches * 14              # 630 px; DINOv2 ViT-B/14 uses 14-px patches
ckpt_file = "./ckpts/dino_spair_0300.pth"

# ─── Model setup ─────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
aggre_net_dummy = DummyAggregationNetwork()
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
aggre_net = aggre_net.eval()
extractor_vit.model.eval()
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
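# Note: equivalent to nn.functional.interpolate(x, size=(target_res, target_res),
# mode='bilinear', align_corners=False); the module form lets get_sim reuse the
# fixed output size.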
gc.collect()
torch.cuda.empty_cache()

# ─── Build Gradio UI ─────────────────────────────────────────────
with gr.Blocks() as demo:
    # Hidden states to hold features and unmarked images
    feat1_state = gr.State()
    feat2_state = gr.State()
    or_tgt_img = gr.State()
    or_src_img = gr.State()

    # Introduction text box
    intro_text = gr.Markdown("""
## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
[Project Page](https://genintel.github.io/DIY-SC) | [GitHub Repository](https://github.com/odunkel/DIY-SC)

Welcome to the DIY-SC demo!
Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
You can choose between the DIY-SC (DINOv2) and the DINOv2 feature extractor.
""")

    # Image upload / display components
    with gr.Row():
        src = gr.Image(interactive=True, type="pil", label="Source Image")
        tgt = gr.Image(interactive=True, type="pil", label="Target Image")

    # Controls
    alpha = gr.State(0.7)
    scatter = gr.State(10)
    use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")

    src.input(
        fn=update_features,
        inputs=[src, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    tgt.input(
        fn=update_features,
        inputs=[tgt, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    use_dummy.change(
        fn=update_features,
        inputs=[or_src_img, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )
    use_dummy.change(
        fn=update_features,
        inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )
    src.select(
        fn=on_select,
        inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
        outputs=[src, tgt]
    )

if __name__ == "__main__":
    demo.launch()