AkinyemiAra commited on
Commit
e1286f2
·
verified ·
1 Parent(s): 0856eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -19
app.py CHANGED
@@ -24,30 +24,49 @@ def get_embedding(image: Image.Image, device="cpu"):
24
  # L2 normalize the embeddings
25
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
26
  return emb
27
-
28
  def get_reference_embeddings():
 
 
 
 
 
29
  if os.path.exists(CACHE_FILE):
30
  with open(CACHE_FILE, "rb") as f:
31
- return pickle.load(f)
32
-
33
- embeddings = {}
34
- # Use GPU for preprocessing reference images too for consistency
35
- device = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
- for img_path in DATASET_DIR.glob("*.jpg"):
38
- img = Image.open(img_path).convert("RGB")
39
- emb = get_embedding(img, device=device)
40
- # Store on CPU to save GPU memory
41
- embeddings[img_path.name] = emb.cpu()
42
 
43
- with open(CACHE_FILE, "wb") as f:
44
- pickle.dump(embeddings, f)
45
- return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  reference_embeddings = get_reference_embeddings()
48
 
49
  @spaces.GPU
50
  def search_similar(query_img):
 
 
 
 
51
  query_emb = get_embedding(query_img, device="cuda")
52
  results = []
53
 
@@ -59,10 +78,21 @@ def search_similar(query_img):
59
  results.append((name, sim))
60
 
61
  results.sort(key=lambda x: x[1], reverse=True)
62
- return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]
 
 
 
 
 
 
 
 
 
63
 
64
- @spaces.GPU
65
  def add_image(name: str, image):
 
 
 
66
  path = DATASET_DIR / f"{name}.jpg"
67
  image.save(path)
68
 
@@ -70,12 +100,13 @@ def add_image(name: str, image):
70
  device = "cuda" if torch.cuda.is_available() else "cpu"
71
  emb = get_embedding(image, device=device)
72
 
73
- # Store on CPU to save memory
74
  reference_embeddings[f"{name}.jpg"] = emb.cpu()
75
 
76
  with open(CACHE_FILE, "wb") as f:
77
  pickle.dump(reference_embeddings, f)
78
- return f"Image {name} added to dataset."
 
79
 
80
  search_interface = gr.Interface(fn=search_similar,
81
  inputs=gr.Image(type="pil", label="Query Image"),
@@ -88,4 +119,4 @@ add_interface = gr.Interface(fn=add_image,
88
  allow_flagging="never")
89
 
90
  demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
91
- demo.launch()
 
24
  # L2 normalize the embeddings
25
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
26
  return emb
27
+
28
  def get_reference_embeddings():
29
+ # Get all current image files
30
+ current_images = set(img_path.name for img_path in DATASET_DIR.glob("*.jpg"))
31
+
32
+ # Load existing cache if it exists
33
+ cached_embeddings = {}
34
  if os.path.exists(CACHE_FILE):
35
  with open(CACHE_FILE, "rb") as f:
36
+ cached_embeddings = pickle.load(f)
 
 
 
 
37
 
38
+ # Check if cache is up to date
39
+ cached_images = set(cached_embeddings.keys())
 
 
 
40
 
41
+ # If cache is missing images or has extra images, rebuild
42
+ if current_images != cached_images:
43
+ print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
44
+ embeddings = {}
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+
47
+ for img_path in DATASET_DIR.glob("*.jpg"):
48
+ print(f"Processing {img_path.name}...")
49
+ img = Image.open(img_path).convert("RGB")
50
+ emb = get_embedding(img, device=device)
51
+ embeddings[img_path.name] = emb.cpu()
52
+
53
+ # Save updated cache
54
+ with open(CACHE_FILE, "wb") as f:
55
+ pickle.dump(embeddings, f)
56
+ print(f"Cache updated with {len(embeddings)} images")
57
+ return embeddings
58
+ else:
59
+ print(f"Using cached embeddings for {len(cached_embeddings)} images")
60
+ return cached_embeddings
61
 
62
  reference_embeddings = get_reference_embeddings()
63
 
64
  @spaces.GPU
65
  def search_similar(query_img):
66
+ # Refresh embeddings to catch any new images
67
+ global reference_embeddings
68
+ reference_embeddings = get_reference_embeddings()
69
+
70
  query_emb = get_embedding(query_img, device="cuda")
71
  results = []
72
 
 
78
  results.append((name, sim))
79
 
80
  results.sort(key=lambda x: x[1], reverse=True)
81
+
82
+ # Filter out low similarity results (adjust threshold as needed)
83
+ SIMILARITY_THRESHOLD = 0.2 # Only show results above 20% similarity
84
+ filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
85
+
86
+ if not filtered_results:
87
+ return [("No similar images found", "No matches above similarity threshold")]
88
+
89
+ # Return top 5 results
90
+ return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
91
 
 
92
  def add_image(name: str, image):
93
+ if not name.strip():
94
+ return "Please provide a valid image name."
95
+
96
  path = DATASET_DIR / f"{name}.jpg"
97
  image.save(path)
98
 
 
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
  emb = get_embedding(image, device=device)
102
 
103
+ # Add to current embeddings and save cache
104
  reference_embeddings[f"{name}.jpg"] = emb.cpu()
105
 
106
  with open(CACHE_FILE, "wb") as f:
107
  pickle.dump(reference_embeddings, f)
108
+
109
+ return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"
110
 
111
  search_interface = gr.Interface(fn=search_similar,
112
  inputs=gr.Image(type="pil", label="Query Image"),
 
119
  allow_flagging="never")
120
 
121
  demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
122
+ demo.launch()