odunkel commited on
Commit
a106d62
Β·
verified Β·
1 Parent(s): d3e5665

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -54,7 +54,7 @@ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_f
54
  return canvas
55
 
56
  # ─── Feature extraction ──────────────────────────────────────────
57
- @spaces.GPU
58
  def get_processed_features_dino(num_patches, img,use_dummy):
59
  batch = extractor_vit.preprocess_pil(img)
60
  features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
@@ -68,11 +68,10 @@ def get_processed_features_dino(num_patches, img,use_dummy):
68
  desc = aggre_net(features_dino)
69
  norms = torch.linalg.norm(desc, dim=1, keepdim=True)
70
  desc = desc / (norms + 1e-8)
71
- desc = desc.cpu()
72
- torch.cuda.empty_cache()
73
  return desc # shape [1, C, num_patches, num_patches]
74
 
75
  # ─── Similarity computation ───────────────────────────────────────
 
76
  def get_sim(
77
  coord: tuple[int,int],
78
  feat1: torch.Tensor,
@@ -97,7 +96,9 @@ def get_sim(
97
  # Cosine similarity along channel‐dim
98
  cos = nn.CosineSimilarity(dim=1)
99
  cos_map = cos(src_vec, trg_ft)[0] # [img_size, img_size]
100
- return cos_map.cpu().numpy()
 
 
101
 
102
  # ─── Drawing helper ───────────────────────────────────────────────
103
  def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255,0,0)) -> np.ndarray:
@@ -108,6 +109,7 @@ def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255,0,0))
108
  return np.array(pil)
109
 
110
  # ─── Feature‐updating callback ───────────────────────────────────
 
111
  def update_features(
112
  img: Image,
113
  num_patches,
@@ -122,7 +124,7 @@ def update_features(
122
  return None, None, None
123
  img = resize(img, target_res=target_res, resize=True, to_pil=True)
124
  feat = get_processed_features_dino(num_patches, img=img,use_dummy=use_dummy)
125
- return img, feat.cpu(), Image.fromarray(np.array(img))
126
 
127
  # ─── Click handler ───────────────────────────────────────────────
128
  def on_select(
 
54
  return canvas
55
 
56
  # ─── Feature extraction ──────────────────────────────────────────
57
+ @spaces.GPU(duration=20)
58
  def get_processed_features_dino(num_patches, img,use_dummy):
59
  batch = extractor_vit.preprocess_pil(img)
60
  features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
 
68
  desc = aggre_net(features_dino)
69
  norms = torch.linalg.norm(desc, dim=1, keepdim=True)
70
  desc = desc / (norms + 1e-8)
 
 
71
  return desc # shape [1, C, num_patches, num_patches]
72
 
73
  # ─── Similarity computation ───────────────────────────────────────
74
+ spaces.GPU(duration=20)
75
  def get_sim(
76
  coord: tuple[int,int],
77
  feat1: torch.Tensor,
 
96
  # Cosine similarity along channel‐dim
97
  cos = nn.CosineSimilarity(dim=1)
98
  cos_map = cos(src_vec, trg_ft)[0] # [img_size, img_size]
99
+ cos_map = cos_map.cpu()
100
+ torch.cuda.empty_cache()
101
+ return cos_map.numpy()
102
 
103
  # ─── Drawing helper ───────────────────────────────────────────────
104
  def draw_point(img_arr: np.ndarray, x: int, y: int, size: int, color=(255,0,0)) -> np.ndarray:
 
109
  return np.array(pil)
110
 
111
  # ─── Feature‐updating callback ───────────────────────────────────
112
+ spaces.GPU(duration=20)
113
  def update_features(
114
  img: Image,
115
  num_patches,
 
124
  return None, None, None
125
  img = resize(img, target_res=target_res, resize=True, to_pil=True)
126
  feat = get_processed_features_dino(num_patches, img=img,use_dummy=use_dummy)
127
+ return img, feat, Image.fromarray(np.array(img))
128
 
129
  # ─── Click handler ───────────────────────────────────────────────
130
  def on_select(