odunkel commited on
Commit
fd625c4
Β·
verified Β·
1 Parent(s): 98aae95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import spaces
4
  import gradio as gr
5
  import numpy as np
 
6
  import torch
7
  import torch.nn as nn
8
  from PIL import Image, ImageDraw
@@ -54,21 +55,22 @@ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_f
54
  return canvas
55
 
56
  # ─── Feature extraction ──────────────────────────────────────────
57
- @spaces.GPU(duration=20)
58
  def get_processed_features_dino(num_patches, img,use_dummy):
59
- batch = extractor_vit.preprocess_pil(img)
60
- features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
61
- .permute(0,1,3,2) \
62
- .reshape(1, -1, num_patches, num_patches)
63
- # Project + normalize
64
  with torch.no_grad():
 
 
 
 
65
  if use_dummy == "DINOv2":
66
  desc = aggre_net_dummy(features_dino)
67
  else:
68
  desc = aggre_net(features_dino)
69
  norms = torch.linalg.norm(desc, dim=1, keepdim=True)
70
  desc = desc / (norms + 1e-8)
71
- desc = desc.cpu()
 
 
72
  torch.cuda.empty_cache()
73
  return desc # shape [1, C, num_patches, num_patches]
74
 
@@ -86,7 +88,6 @@ def get_sim(
86
  y, x = coord # row, col
87
 
88
  # Upsample both feature maps to [1, C, img_size, img_size]
89
- upsampler = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=False)
90
  src_ft = upsampler(feat1) # [1, C, img_size, img_size]
91
  trg_ft = upsampler(feat2)
92
 
@@ -176,7 +177,7 @@ def reload_img(
176
 
177
 
178
  # ─── Configuration ───────────────────────────────────────────────
179
- num_patches = 45
180
  target_res = num_patches * 14
181
  ckpt_file = "./ckpts/dino_spair_0300.pth"
182
 
@@ -188,6 +189,11 @@ aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
188
  aggre_net_dummy = DummyAggregationNetwork()
189
  extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
190
 
 
 
 
 
 
191
  # ─── Build Gradio UI ──────────────────────────────────────────────
192
  with gr.Blocks() as demo:
193
  # Hidden states to hold features
 
3
  import spaces
4
  import gradio as gr
5
  import numpy as np
6
+ import gc
7
  import torch
8
  import torch.nn as nn
9
  from PIL import Image, ImageDraw
 
55
  return canvas
56
 
57
  # ─── Feature extraction ──────────────────────────────────────────
58
+ @spaces.GPU(duration=0)
59
  def get_processed_features_dino(num_patches, img,use_dummy):
 
 
 
 
 
60
  with torch.no_grad():
61
+ batch = extractor_vit.preprocess_pil(img)
62
+ features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
63
+ .permute(0,1,3,2) \
64
+ .reshape(1, -1, num_patches, num_patches)
65
  if use_dummy == "DINOv2":
66
  desc = aggre_net_dummy(features_dino)
67
  else:
68
  desc = aggre_net(features_dino)
69
  norms = torch.linalg.norm(desc, dim=1, keepdim=True)
70
  desc = desc / (norms + 1e-8)
71
+ desc = desc.cpu().detach()
72
+ del batch, features_dino
73
+ gc.collect()
74
  torch.cuda.empty_cache()
75
  return desc # shape [1, C, num_patches, num_patches]
76
 
 
88
  y, x = coord # row, col
89
 
90
  # Upsample both feature maps to [1, C, img_size, img_size]
 
91
  src_ft = upsampler(feat1) # [1, C, img_size, img_size]
92
  trg_ft = upsampler(feat2)
93
 
 
177
 
178
 
179
  # ─── Configuration ───────────────────────────────────────────────
180
+ num_patches = 30
181
  target_res = num_patches * 14
182
  ckpt_file = "./ckpts/dino_spair_0300.pth"
183
 
 
189
  aggre_net_dummy = DummyAggregationNetwork()
190
  extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
191
 
192
+ aggre_net = aggre_net.eval()
193
+ extractor_vit.model.eval()
194
+
195
+ upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
196
+
197
  # ─── Build Gradio UI ──────────────────────────────────────────────
198
  with gr.Blocks() as demo:
199
  # Hidden states to hold features