AkinyemiAra commited on
Commit
cf7175c
·
verified ·
1 Parent(s): 336cab5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -36
app.py CHANGED
@@ -16,17 +16,17 @@ import spaces
16
  from typing import List, Dict, Tuple, Optional, Union
17
 
18
  # Load model/processor
19
- model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
20
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
21
  model.eval()
22
 
23
- DATASET_DIR = Path("dataset")
24
- CACHE_FILE = "cache.pkl"
25
 
26
  # Define supported image formats
27
- IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
28
 
29
- def get_all_image_files() -> List[Path]:
30
  """
31
  Get all image files from the dataset directory.
32
 
@@ -35,7 +35,7 @@ def get_all_image_files() -> List[Path]:
35
  Returns:
36
  List[Path]: List of Path objects for all found image files
37
  """
38
- image_files = []
39
  for ext in IMAGE_EXTENSIONS:
40
  image_files.extend(DATASET_DIR.glob(ext))
41
  image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
@@ -59,7 +59,7 @@ def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
59
  inputs = processor(images=image, return_tensors="pt").to(device)
60
  model_device = model.to(device)
61
  with torch.no_grad():
62
- emb = model_device.get_image_features(**inputs)
63
  # L2 normalize the embeddings
64
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
65
  return emb
@@ -80,29 +80,29 @@ def get_reference_embeddings() -> Dict[str, torch.Tensor]:
80
  PermissionError: If unable to write cache file
81
  """
82
  # Get all current image files
83
- current_image_files = get_all_image_files()
84
- current_images = set(img_path.name for img_path in current_image_files)
85
 
86
  # Load existing cache if it exists
87
- cached_embeddings = {}
88
  if os.path.exists(CACHE_FILE):
89
  with open(CACHE_FILE, "rb") as f:
90
  cached_embeddings = pickle.load(f)
91
 
92
  # Check if cache is up to date
93
- cached_images = set(cached_embeddings.keys())
94
 
95
  # If cache is missing images or has extra images, rebuild
96
  if current_images != cached_images:
97
  print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
98
- embeddings = {}
99
- device = "cuda" if torch.cuda.is_available() else "cpu"
100
 
101
  for img_path in current_image_files:
102
  print(f"Processing {img_path.name}...")
103
  try:
104
- img = Image.open(img_path).convert("RGB")
105
- emb = get_embedding(img, device=device)
106
  embeddings[img_path.name] = emb.cpu()
107
  except Exception as e:
108
  print(f"Error processing {img_path.name}: {e}")
@@ -117,7 +117,8 @@ def get_reference_embeddings() -> Dict[str, torch.Tensor]:
117
  print(f"Using cached embeddings for {len(cached_embeddings)} images")
118
  return cached_embeddings
119
 
120
- reference_embeddings = get_reference_embeddings()
 
121
 
122
  @spaces.GPU
123
  def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
@@ -138,21 +139,21 @@ def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
138
  global reference_embeddings
139
  reference_embeddings = get_reference_embeddings()
140
 
141
- query_emb = get_embedding(query_img, device="cuda")
142
- results = []
143
 
144
  for name, ref_emb in reference_embeddings.items():
145
  # Move reference embedding to same device as query
146
- ref_emb_gpu = ref_emb.to("cuda")
147
  # Compute cosine similarity
148
- sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb_gpu, dim=1).item()
149
  results.append((name, sim))
150
 
151
  results.sort(key=lambda x: x[1], reverse=True)
152
 
153
  # Filter out low similarity results (adjust threshold as needed)
154
- SIMILARITY_THRESHOLD = 0.2 # Only show results above 20% similarity
155
- filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
156
 
157
  if not filtered_results:
158
  return [("No similar images found", "No matches above similarity threshold")]
@@ -181,12 +182,12 @@ def add_image(name: str, image: Image.Image) -> str:
181
  return "Please provide a valid image name."
182
 
183
  # Save as PNG to preserve quality for all input formats
184
- path = DATASET_DIR / f"{name}.png"
185
  image.save(path, "PNG")
186
 
187
  # Use GPU for consistency if available
188
- device = "cuda" if torch.cuda.is_available() else "cpu"
189
- emb = get_embedding(image, device=device)
190
 
191
  # Add to current embeddings and save cache
192
  reference_embeddings[f"{name}.png"] = emb.cpu()
@@ -196,15 +197,37 @@ def add_image(name: str, image: Image.Image) -> str:
196
 
197
  return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"
198
 
199
- search_interface = gr.Interface(fn=search_similar,
200
- inputs=gr.Image(type="pil", label="Query Image"),
201
- outputs=gr.Gallery(label="Top Matches", columns=5),
202
- allow_flagging="never")
 
 
 
 
 
203
 
204
- add_interface = gr.Interface(fn=add_image,
205
- inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
206
- outputs="text",
207
- allow_flagging="never")
 
 
 
 
 
 
 
208
 
209
- demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
210
- demo.launch(mcp_server=True)
 
 
 
 
 
 
 
 
 
 
 
16
  from typing import List, Dict, Tuple, Optional, Union
17
 
18
  # Load model/processor
19
+ model: CLIPModel = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
20
+ processor: CLIPProcessor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
21
  model.eval()
22
 
23
+ DATASET_DIR: Path = Path("dataset")
24
+ CACHE_FILE: str = "cache.pkl"
25
 
26
  # Define supported image formats
27
+ IMAGE_EXTENSIONS: List[str] = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
28
 
29
+ def get_all_image_files() -> List[Path]:
30
  """
31
  Get all image files from the dataset directory.
32
 
 
35
  Returns:
36
  List[Path]: List of Path objects for all found image files
37
  """
38
+ image_files: List[Path] = []
39
  for ext in IMAGE_EXTENSIONS:
40
  image_files.extend(DATASET_DIR.glob(ext))
41
  image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
 
59
  inputs = processor(images=image, return_tensors="pt").to(device)
60
  model_device = model.to(device)
61
  with torch.no_grad():
62
+ emb: torch.Tensor = model_device.get_image_features(**inputs)
63
  # L2 normalize the embeddings
64
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
65
  return emb
 
80
  PermissionError: If unable to write cache file
81
  """
82
  # Get all current image files
83
+ current_image_files: List[Path] = get_all_image_files()
84
+ current_images: set = set(img_path.name for img_path in current_image_files)
85
 
86
  # Load existing cache if it exists
87
+ cached_embeddings: Dict[str, torch.Tensor] = {}
88
  if os.path.exists(CACHE_FILE):
89
  with open(CACHE_FILE, "rb") as f:
90
  cached_embeddings = pickle.load(f)
91
 
92
  # Check if cache is up to date
93
+ cached_images: set = set(cached_embeddings.keys())
94
 
95
  # If cache is missing images or has extra images, rebuild
96
  if current_images != cached_images:
97
  print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
98
+ embeddings: Dict[str, torch.Tensor] = {}
99
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
100
 
101
  for img_path in current_image_files:
102
  print(f"Processing {img_path.name}...")
103
  try:
104
+ img: Image.Image = Image.open(img_path).convert("RGB")
105
+ emb: torch.Tensor = get_embedding(img, device=device)
106
  embeddings[img_path.name] = emb.cpu()
107
  except Exception as e:
108
  print(f"Error processing {img_path.name}: {e}")
 
117
  print(f"Using cached embeddings for {len(cached_embeddings)} images")
118
  return cached_embeddings
119
 
120
+ # Initialize reference embeddings
121
+ reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings()
122
 
123
  @spaces.GPU
124
  def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
 
139
  global reference_embeddings
140
  reference_embeddings = get_reference_embeddings()
141
 
142
+ query_emb: torch.Tensor = get_embedding(query_img, device="cuda")
143
+ results: List[Tuple[str, float]] = []
144
 
145
  for name, ref_emb in reference_embeddings.items():
146
  # Move reference embedding to same device as query
147
+ ref_emb_gpu: torch.Tensor = ref_emb.to("cuda")
148
  # Compute cosine similarity
149
+ sim: float = torch.nn.functional.cosine_similarity(query_emb, ref_emb_gpu, dim=1).item()
150
  results.append((name, sim))
151
 
152
  results.sort(key=lambda x: x[1], reverse=True)
153
 
154
  # Filter out low similarity results (adjust threshold as needed)
155
+ SIMILARITY_THRESHOLD: float = 0.2 # Only show results above 20% similarity
156
+ filtered_results: List[Tuple[str, float]] = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
157
 
158
  if not filtered_results:
159
  return [("No similar images found", "No matches above similarity threshold")]
 
182
  return "Please provide a valid image name."
183
 
184
  # Save as PNG to preserve quality for all input formats
185
+ path: Path = DATASET_DIR / f"{name}.png"
186
  image.save(path, "PNG")
187
 
188
  # Use GPU for consistency if available
189
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
190
+ emb: torch.Tensor = get_embedding(image, device=device)
191
 
192
  # Add to current embeddings and save cache
193
  reference_embeddings[f"{name}.png"] = emb.cpu()
 
197
 
198
  return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"
199
 
200
+ # Create Gradio interfaces
201
+ search_interface: gr.Interface = gr.Interface(
202
+ fn=search_similar,
203
+ inputs=gr.Image(type="pil", label="Query Image"),
204
+ outputs=gr.Gallery(label="Top Matches", columns=5),
205
+ allow_flagging="never",
206
+ title="Image Similarity Search",
207
+ description="Upload an image to find similar images in the dataset"
208
+ )
209
 
210
+ add_interface: gr.Interface = gr.Interface(
211
+ fn=add_image,
212
+ inputs=[
213
+ gr.Text(label="Image Name", placeholder="Enter a unique name for your image"),
214
+ gr.Image(type="pil", label="Product Image")
215
+ ],
216
+ outputs="text",
217
+ allow_flagging="never",
218
+ title="Add Image to Dataset",
219
+ description="Add a new image to the searchable dataset"
220
+ )
221
 
222
+ # Create main application
223
+ demo: gr.TabbedInterface = gr.TabbedInterface(
224
+ [search_interface, add_interface],
225
+ tab_names=["Search", "Add Product"],
226
+ title="CLIP Image Search System",
227
+ theme=gr.themes.Soft()
228
+ )
229
+
230
+ if __name__ == "__main__":
231
+ # Ensure dataset directory exists
232
+ DATASET_DIR.mkdir(exist_ok=True)
233
+ demo.launch(share=True)