odunkel committed on
Commit
170b2d9
·
verified ·
1 Parent(s): d2e532c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -52,18 +52,6 @@ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_f
52
  canvas = Image.fromarray(canvas)
53
  return canvas
54
 
55
- # ─── Configuration ───────────────────────────────────────────────
56
- num_patches = 30
57
- target_res = num_patches * 14
58
- ckpt_file = "ckpts/dino_spair_0300.pth"
59
-
60
- # ─── Model setup ─────────────────────────────────────────────────
61
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
62
- aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
63
- aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
64
- aggre_net_dummy = DummyAggregationNetwork()
65
- extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
66
-
67
  # ─── Feature extraction ──────────────────────────────────────────
68
  def get_processed_features_dino(num_patches, img,use_dummy):
69
  batch = extractor_vit.preprocess_pil(img)
@@ -85,7 +73,7 @@ def get_sim(
85
  coord: tuple[int,int],
86
  feat1: torch.Tensor,
87
  feat2: torch.Tensor,
88
- img_size: int = target_res
89
  ) -> np.ndarray:
90
  """
91
  Upsamples the DINO features to `img_size`, then computes cosine‐similarity
@@ -182,6 +170,18 @@ def reload_img(
182
 
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # ─── Build Gradio UI ──────────────────────────────────────────────
186
  with gr.Blocks() as demo:
187
  # Hidden states to hold features
@@ -194,7 +194,7 @@ with gr.Blocks() as demo:
194
  intro_text = gr.Markdown("""
195
  ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
196
  [Project Page](https://example.com) | [GitHub Repository](https://github.com/example/repo)
197
-
198
  Welcome to the DIY-SC demo!
199
  Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
200
  You can choose between the DIY-SC (DINOv2) or the DINOv2 feature extractor.
@@ -240,4 +240,5 @@ with gr.Blocks() as demo:
240
  outputs=[src,tgt]
241
  )
242
 
243
- demo.launch(share=True)
 
 
52
  canvas = Image.fromarray(canvas)
53
  return canvas
54
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # ─── Feature extraction ──────────────────────────────────────────
56
  def get_processed_features_dino(num_patches, img,use_dummy):
57
  batch = extractor_vit.preprocess_pil(img)
 
73
  coord: tuple[int,int],
74
  feat1: torch.Tensor,
75
  feat2: torch.Tensor,
76
+ img_size: int = 420
77
  ) -> np.ndarray:
78
  """
79
  Upsamples the DINO features to `img_size`, then computes cosine‐similarity
 
170
 
171
 
172
 
173
+ # ─── Configuration ───────────────────────────────────────────────
174
+ num_patches = 30
175
+ target_res = num_patches * 14
176
+ ckpt_file = "ckpts/dino_spair_0300.pth"
177
+
178
+ # ─── Model setup ─────────────────────────────────────────────────
179
+ device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu'
180
+ aggre_net = AggregationNetwork(feature_dims=[768], projection_dim=768, device=device)
181
+ aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
182
+ aggre_net_dummy = DummyAggregationNetwork()
183
+ extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
184
+
185
  # ─── Build Gradio UI ──────────────────────────────────────────────
186
  with gr.Blocks() as demo:
187
  # Hidden states to hold features
 
194
  intro_text = gr.Markdown("""
195
  ## Do It Yourself: Learning Semantic Correspondence from Pseudo-Labels
196
  [Project Page](https://example.com) | [GitHub Repository](https://github.com/example/repo)
197
+
198
  Welcome to the DIY-SC demo!
199
  Upload two images and select a keypoint in the source image. This demo will compute and visualize the feature similarity map and a corresponding point in the target image.
200
  You can choose between the DIY-SC (DINOv2) or the DINOv2 feature extractor.
 
240
  outputs=[src,tgt]
241
  )
242
 
243
+ if __name__ == "__main__":
244
+ demo.launch(share=True)