Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import spaces
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from PIL import Image, ImageDraw
|
@@ -54,21 +55,22 @@ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_f
|
|
54 |
return canvas
|
55 |
|
56 |
# ─── Feature extraction ──────────────────────────────────────────
|
57 |
-
@spaces.GPU(duration=
|
58 |
def get_processed_features_dino(num_patches, img,use_dummy):
|
59 |
-
batch = extractor_vit.preprocess_pil(img)
|
60 |
-
features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
|
61 |
-
.permute(0,1,3,2) \
|
62 |
-
.reshape(1, -1, num_patches, num_patches)
|
63 |
-
# Project + normalize
|
64 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
65 |
if use_dummy == "DINOv2":
|
66 |
desc = aggre_net_dummy(features_dino)
|
67 |
else:
|
68 |
desc = aggre_net(features_dino)
|
69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
70 |
desc = desc / (norms + 1e-8)
|
71 |
-
desc = desc.cpu()
|
|
|
|
|
72 |
torch.cuda.empty_cache()
|
73 |
return desc # shape [1, C, num_patches, num_patches]
|
74 |
|
@@ -86,7 +88,6 @@ def get_sim(
|
|
86 |
y, x = coord # row, col
|
87 |
|
88 |
# Upsample both feature maps to [1, C, img_size, img_size]
|
89 |
-
upsampler = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=False)
|
90 |
src_ft = upsampler(feat1) # [1, C, img_size, img_size]
|
91 |
trg_ft = upsampler(feat2)
|
92 |
|
@@ -176,7 +177,7 @@ def reload_img(
|
|
176 |
|
177 |
|
178 |
# ─── Configuration ───────────────────────────────────────────────
|
179 |
-
num_patches =
|
180 |
target_res = num_patches * 14
|
181 |
ckpt_file = "./ckpts/dino_spair_0300.pth"
|
182 |
|
@@ -188,6 +189,11 @@ aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
|
|
188 |
aggre_net_dummy = DummyAggregationNetwork()
|
189 |
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
|
190 |
|
|
|
|
|
|
|
|
|
|
|
191 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
192 |
with gr.Blocks() as demo:
|
193 |
# Hidden states to hold features
|
|
|
3 |
import spaces
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
+
import gc
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
from PIL import Image, ImageDraw
|
|
|
55 |
return canvas
|
56 |
|
57 |
# ─── Feature extraction ──────────────────────────────────────────
|
58 |
+
@spaces.GPU(duration=0)
|
59 |
def get_processed_features_dino(num_patches, img,use_dummy):
|
|
|
|
|
|
|
|
|
|
|
60 |
with torch.no_grad():
|
61 |
+
batch = extractor_vit.preprocess_pil(img)
|
62 |
+
features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
|
63 |
+
.permute(0,1,3,2) \
|
64 |
+
.reshape(1, -1, num_patches, num_patches)
|
65 |
if use_dummy == "DINOv2":
|
66 |
desc = aggre_net_dummy(features_dino)
|
67 |
else:
|
68 |
desc = aggre_net(features_dino)
|
69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
70 |
desc = desc / (norms + 1e-8)
|
71 |
+
desc = desc.cpu().detach()
|
72 |
+
del batch, features_dino
|
73 |
+
gc.collect()
|
74 |
torch.cuda.empty_cache()
|
75 |
return desc # shape [1, C, num_patches, num_patches]
|
76 |
|
|
|
88 |
y, x = coord # row, col
|
89 |
|
90 |
# Upsample both feature maps to [1, C, img_size, img_size]
|
|
|
91 |
src_ft = upsampler(feat1) # [1, C, img_size, img_size]
|
92 |
trg_ft = upsampler(feat2)
|
93 |
|
|
|
177 |
|
178 |
|
179 |
# ─── Configuration ───────────────────────────────────────────────
|
180 |
+
num_patches = 30
|
181 |
target_res = num_patches * 14
|
182 |
ckpt_file = "./ckpts/dino_spair_0300.pth"
|
183 |
|
|
|
189 |
aggre_net_dummy = DummyAggregationNetwork()
|
190 |
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
|
191 |
|
192 |
+
aggre_net = aggre_net.eval()
|
193 |
+
extractor_vit.model.eval()
|
194 |
+
|
195 |
+
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
196 |
+
|
197 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
198 |
with gr.Blocks() as demo:
|
199 |
# Hidden states to hold features
|