AkinyemiAra commited on
Commit
a4d053b
·
verified ·
1 Parent(s): f77723a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -8,35 +8,39 @@ import os
8
  import spaces
9
 
10
  # Load model/processor
11
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
12
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
13
  model.eval()
14
 
15
  DATASET_DIR = Path("dataset")
16
  CACHE_FILE = "cache.pkl"
17
 
18
- def preprocess_image(image: Image.Image) -> Image.Image:
19
- return image.resize((224, 224)).convert("RGB")
20
-
21
  def get_embedding(image: Image.Image, device="cpu"):
22
- image = preprocess_image(image)
23
  inputs = processor(images=image, return_tensors="pt").to(device)
24
  model_device = model.to(device)
25
  with torch.no_grad():
26
  emb = model_device.get_image_features(**inputs)
 
27
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
28
  return emb
29
-
 
30
  def get_reference_embeddings():
31
  if os.path.exists(CACHE_FILE):
32
  with open(CACHE_FILE, "rb") as f:
33
  return pickle.load(f)
34
 
35
  embeddings = {}
 
 
 
36
  for img_path in DATASET_DIR.glob("*.jpg"):
37
  img = Image.open(img_path).convert("RGB")
38
- emb = get_embedding(img)
39
- embeddings[img_path.name] = emb
 
 
40
  with open(CACHE_FILE, "wb") as f:
41
  pickle.dump(embeddings, f)
42
  return embeddings
@@ -47,24 +51,35 @@ reference_embeddings = get_reference_embeddings()
47
  def search_similar(query_img):
48
  query_emb = get_embedding(query_img, device="cuda")
49
  results = []
 
50
  for name, ref_emb in reference_embeddings.items():
51
- sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb.to("cuda")).item()
 
 
 
52
  results.append((name, sim))
 
53
  results.sort(key=lambda x: x[1], reverse=True)
54
  return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]
55
 
56
  def add_image(name: str, image):
57
  path = DATASET_DIR / f"{name}.jpg"
58
  image.save(path)
59
- emb = get_embedding(image)
60
- reference_embeddings[f"{name}.jpg"] = emb
 
 
 
 
 
 
61
  with open(CACHE_FILE, "wb") as f:
62
  pickle.dump(reference_embeddings, f)
63
  return f"Image {name} added to dataset."
64
 
65
  search_interface = gr.Interface(fn=search_similar,
66
  inputs=gr.Image(type="pil", label="Query Image"),
67
- outputs=gr.Gallery(label="Top Matches"),
68
  allow_flagging="never")
69
 
70
  add_interface = gr.Interface(fn=add_image,
 
8
  import spaces
9
 
10
# Load model/processor once at import time.
# NOTE(review): "openai/clip-vit-base-patch14" does not exist on the Hub —
# the official OpenAI CLIP checkpoints are clip-vit-base-patch32,
# clip-vit-base-patch16 and clip-vit-large-patch14. Using large-patch14
# here; if the base model was intended, use clip-vit-base-patch16/32.
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model.eval()  # inference only — disable dropout/batch-norm updates

DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"
# NOTE(review): switching checkpoints changes the embedding dimension, so a
# pre-existing cache.pkl holds stale/incompatible embeddings — delete it
# after changing the model id above.
18
def get_embedding(image: Image.Image, device="cpu"):
    """Return an L2-normalized CLIP embedding for *image* as a (1, D) tensor.

    The CLIP processor performs all resizing/normalization, so the image
    needs no manual preprocessing before this call.
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    net = model.to(device)  # note: moves the shared module in place
    with torch.no_grad():
        features = net.get_image_features(**batch)
    # Unit-normalize so dot products equal cosine similarity.
    return features / features.norm(p=2, dim=-1, keepdim=True)
27
+
28
@spaces.GPU  # fixed: the module is imported as `spaces`, not `space`
def get_reference_embeddings():
    """Load (or build and cache) embeddings for every image in DATASET_DIR.

    Returns:
        dict mapping filename -> CPU tensor embedding of shape (1, D).

    The result is pickled to CACHE_FILE so the dataset is only embedded
    once; subsequent calls return the cached dict directly.
    """
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    # Embed on GPU when available for speed; store results on CPU below
    # so cached reference embeddings don't hold GPU memory.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for img_path in DATASET_DIR.glob("*.jpg"):
        img = Image.open(img_path).convert("RGB")
        embeddings[img_path.name] = get_embedding(img, device=device).cpu()

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings
 
51
def search_similar(query_img):
    """Return the 5 dataset images most similar to *query_img*.

    Returns:
        list of (image_path, "Score: x.xxxx") tuples for gr.Gallery.
    """
    # Fall back to CPU so the search also works on hosts without a GPU —
    # the indexing paths already guard on torch.cuda.is_available(), and a
    # hardcoded "cuda" would raise here on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)

    results = []
    for name, ref_emb in reference_embeddings.items():
        # Cosine similarity between the (1, D) query and reference vectors,
        # computed on a common device.
        sim = torch.nn.functional.cosine_similarity(
            query_emb, ref_emb.to(device), dim=1
        ).item()
        results.append((name, sim))

    results.sort(key=lambda x: x[1], reverse=True)
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]
64
 
65
def add_image(name: str, image):
    """Save *image* into the dataset and index its embedding.

    Updates both the in-memory `reference_embeddings` dict and the on-disk
    cache so future searches (and app restarts) see the new image.

    Returns:
        A human-readable confirmation string for the Gradio interface.
    """
    # Robustness: make sure the dataset directory exists before saving.
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    path = DATASET_DIR / f"{name}.jpg"
    image.save(path)

    # Embed on GPU when available, mirroring the indexing path.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Keep the stored embedding on CPU to save GPU memory.
    reference_embeddings[f"{name}.jpg"] = get_embedding(image, device=device).cpu()

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {name} added to dataset."
79
 
80
# Search tab: upload a query image, get the 5 closest dataset matches.
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)
84
 
85
  add_interface = gr.Interface(fn=add_image,