import os
import spaces
import gradio as gr
import numpy as np
import gc
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from matplotlib import cm

from model_utils.extractor_dino import ViTExtractor
from model_utils.projection_network import AggregationNetwork, DummyAggregationNetwork
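# ViTExtractor and the aggregation networks are local modules shipped with
# this demo (see the DIY-SC repository linked in the UI below)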

def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_filter='lanczos'):
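    """
    Resize `img` so its longer side equals `target_res` (when `resize=True`),
    then pad it to a square target_res x target_res canvas: zero (black)
    padding by default, or edge-replication when `edge=True`.
    Returns a PIL image if `to_pil` is True, otherwise a numpy array.
    """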
    filt = Image.Resampling.LANCZOS if sampling_filter == 'lanczos' else Image.Resampling.NEAREST
    original_width, original_height = img.size
    original_channels = len(img.getbands())
    if not edge:
        canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8)
        if original_channels == 1:
            canvas = np.zeros([target_res, target_res], dtype=np.uint8)
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[(width - height) // 2: (width + height) // 2] = img
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            canvas[:, (height - width) // 2: (height + width) // 2] = img
    else:
        if original_height <= original_width:
            if resize:
                img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), filt)
            width, height = img.size
            img = np.asarray(img)
            top_pad = (target_res - height) // 2
            bottom_pad = target_res - height - top_pad
            img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge')
        else:
            if resize:
                img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), filt)
            width, height = img.size
            img = np.asarray(img)
            left_pad = (target_res - width) // 2
            right_pad = target_res - width - left_pad
            img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge')
        canvas = img
    if to_pil:
        canvas = Image.fromarray(canvas)
    return canvas
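
# Illustrative usage (the file name is hypothetical):
#   square = resize(Image.open("example.jpg"), target_res=630, to_pil=True)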

# ─── Feature extraction ──────────────────────────────────────────
@spaces.GPU(duration=20)
def get_processed_features_dino(num_patches, img, use_dummy):
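    """
    Extract DINOv2 token descriptors for `img` (layer 11) and project them
    with the DIY-SC aggregation network; when the plain "DINOv2" extractor
    is selected, the dummy network is used instead. Returns an L2-normalized
    descriptor tensor of shape [1, C, num_patches, num_patches] on the CPU.
    """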
    with torch.no_grad():
        batch = extractor_vit.preprocess_pil(img)
        features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
                                        .permute(0,1,3,2) \
                                        .reshape(1, -1, num_patches, num_patches)
        if use_dummy == "DINOv2":
            desc = aggre_net_dummy(features_dino)
        else:
            desc = aggre_net(features_dino)
        norms = torch.linalg.norm(desc, dim=1, keepdim=True)
        desc = desc / (norms + 1e-8)
    desc = desc.cpu()
    # del batch, features_dino
    gc.collect()
    torch.cuda.empty_cache()
    return desc  # shape [1, C, num_patches, num_patches]

# ─── Similarity computation ───────────────────────────────────────
def get_sim(
    coord: tuple[int,int],
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    img_size: int = 420
) -> np.ndarray:
    """
    Upsamples the DINO features to `img_size`, then computes cosine‐similarity
    between the feature at `coord` in source and every spatial location in target.
    """
    y, x = coord  # row, col

    # Upsample both feature maps to [1, C, img_size, img_size]
    src_ft = upsampler(feat1)  # [1, C, img_size, img_size]
    trg_ft = upsampler(feat2)

    # Extract the C-dim feature vector at the clicked location
    C = src_ft.size(1)
    src_vec = src_ft[0, :, y, x].view(1, C, 1, 1)  # [1, C, 1, 1]

    # Cosine similarity along the channel dim
    cos = nn.CosineSimilarity(dim=1)
    cos_map = cos(src_vec, trg_ft)[0]  # [img_size, img_size]
    return cos_map.cpu().numpy()
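
# A minimal usage sketch (variable names here are illustrative):
#   sim = get_sim((row, col), feat_src, feat_tgt, img_size=target_res)
#   best_row, best_col = np.unravel_index(sim.argmax(), sim.shape)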

# ─── Drawing helper ───────────────────────────────────────────────
def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255,0,0)) -> np.ndarray:
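    """Draw a filled circle of diameter `size`, centered at (x, y), and
    return the result as a numpy array."""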
    pil = Image.fromarray(img_arr)
    draw = ImageDraw.Draw(pil)
    r = size // 2
    draw.ellipse((x-r, y-r, x+r, y+r), fill=color, outline=color)
    return np.array(pil)

# ─── Feature-updating callback ───────────────────────────────────
def update_features(
    img: Image.Image,
    num_patches,
    use_dummy
):
    """
    Given a PIL image, returns:
      1) the resized PIL image (so it can be displayed)
      2) its descriptor tensor, to be stored in a gr.State
      3) a pristine copy of the resized image, used to redraw overlays
    """
    gc.collect()
    torch.cuda.empty_cache()
    if img is None:
        return None, None, None
    img = resize(img, target_res=target_res, resize=True, to_pil=True)
    feat = get_processed_features_dino(num_patches, img=img, use_dummy=use_dummy)
    return img, feat.cpu(), Image.fromarray(np.array(img))

# ─── Click handler ───────────────────────────────────────────────
def on_select(
    source_pil: Image.Image,
    target_pil: Image.Image,
    feat1: torch.Tensor,
    feat2: torch.Tensor,
    alpha: float,
    scatter_size: int,
    or_tgt_img: Image.Image,
    or_src_img: Image.Image,
    sel: gr.SelectData
):
    gc.collect()
    torch.cuda.empty_cache()
    # Convert the pristine images to numpy arrays
    src_arr = np.array(or_src_img)
    tgt_arr = np.array(or_tgt_img)

    # Gradio reports click coordinates as (x, y) = (col, row)
    x, y = sel.index

    src_marked = draw_point(src_arr, x, y, scatter_size)

    # Compute similarity map; get_sim expects (row, col)
    sim_map = get_sim((y, x), feat1, feat2, img_size=target_res)

    mn, mx = sim_map.min(), sim_map.max()
    sim_norm = (sim_map - mn) / ((mx - mn) + 1e-12)

    # Build RGBA heatmap
    heat = cm.viridis(sim_norm)           # [H, W, 4] RGBA
    heat[..., 3] = sim_norm * alpha       # alpha channel

    # Composite over fresh target
    tgt_f = tgt_arr.astype(np.float32) / 255.0
    comp = heat[..., :3] * heat[..., 3:4] + tgt_f * (1 - heat[..., 3:4])
    overlay = (comp * 255).astype(np.uint8)

    # Draw a red dot at the best match (argmax of the similarity map)
    best_row, best_col = np.unravel_index(sim_map.argmax(), sim_map.shape)
    overlay_marked = draw_point(overlay, best_col, best_row, scatter_size)

    return src_marked, overlay_marked

def reload_img(
        or_src_img: Image.Image,
        or_tgt_img: Image.Image,
):
    return or_src_img, or_tgt_img


# ─── Configuration ───────────────────────────────────────────────
num_patches = 45                  # ViT patch grid is 45 x 45
target_res = num_patches * 14     # 630 px; DINOv2 ViT-B/14 uses 14 px patches
ckpt_file = "./ckpts/dino_spair_0300.pth"

# ─── Model setup ─────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
aggre_net_dummy = DummyAggregationNetwork()
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)

aggre_net = aggre_net.eval()
extractor_vit.model.eval()

# Bilinear upsampler from the patch-resolution descriptor grid to pixel resolution
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)

gc.collect()
torch.cuda.empty_cache()

# ─── Build Gradio UI ──────────────────────────────────────────────
with gr.Blocks() as demo:
    # Hidden states to hold features
    feat1_state = gr.State()
    feat2_state = gr.State()
    or_tgt_img = gr.State()
    or_src_img = gr.State()
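    # or_src_img / or_tgt_img hold pristine copies of the resized images so
    # overlays can be redrawn from scratch on every click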

    # Introduction text box
    intro_text = gr.Markdown("""
    ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
    [Project Page](https://genintel.github.io/DIY-SC) | [GitHub Repository](https://github.com/odunkel/DIY-SC)
                            
    Welcome to the DIY-SC demo!
    Upload two images and click a keypoint in the source image. The demo computes the feature similarity map and visualizes it, together with the best-matching point, in the target image.
    You can choose between the DIY-SC (DINOv2) and the plain DINOv2 feature extractor.
    """)

    # Image upload / display components
    with gr.Row():
        src = gr.Image(interactive=True, type="pil", label="Source Image")
        tgt = gr.Image(interactive=True, type="pil", label="Target Image")

    # Controls
    alpha   = gr.State(0.7)
    scatter = gr.State(10)
    use_dummy = gr.Radio(["DIY-SC", "DINOv2"], value="DIY-SC", label="Feature Extractor")

    src.input(
        fn=update_features,
        inputs=[src, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )

    tgt.input(
        fn=update_features,
        inputs=[tgt, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )

    # Switching the extractor recomputes features for both images
    use_dummy.change(
        fn=update_features,
        inputs=[or_src_img, gr.State(num_patches), use_dummy],
        outputs=[src, feat1_state, or_src_img]
    )

    use_dummy.change(
        fn=update_features,
        inputs=[or_tgt_img, gr.State(num_patches), use_dummy],
        outputs=[tgt, feat2_state, or_tgt_img]
    )

    src.select(
        fn=on_select,
        inputs=[src, tgt, feat1_state, feat2_state, alpha, scatter, or_tgt_img, or_src_img],
        outputs=[src, tgt]
    )

if __name__ == "__main__":
    demo.launch()