AkinyemiAra committed on
Commit
55bb1f4
·
verified ·
1 Parent(s): c0e2011

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -6
app.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
@@ -18,15 +25,35 @@ CACHE_FILE = "cache.pkl"
18
  # Define supported image formats
19
  IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
20
 
21
- def get_all_image_files():
22
- """Get all image files from dataset directory"""
 
 
 
 
 
 
 
23
  image_files = []
24
  for ext in IMAGE_EXTENSIONS:
25
  image_files.extend(DATASET_DIR.glob(ext))
26
  image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
27
  return image_files
28
 
29
- def get_embedding(image: Image.Image, device="cpu"):
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Use CLIP's built-in preprocessing
31
  inputs = processor(images=image, return_tensors="pt").to(device)
32
  model_device = model.to(device)
@@ -37,7 +64,20 @@ def get_embedding(image: Image.Image, device="cpu"):
37
  return emb
38
 
39
  @spaces.GPU
40
- def get_reference_embeddings():
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Get all current image files
42
  current_image_files = get_all_image_files()
43
  current_images = set(img_path.name for img_path in current_image_files)
@@ -79,7 +119,20 @@ def get_reference_embeddings():
79
  reference_embeddings = get_reference_embeddings()
80
 
81
  @spaces.GPU
82
- def search_similar(query_img):
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Refresh embeddings to catch any new images
84
  global reference_embeddings
85
  reference_embeddings = get_reference_embeddings()
@@ -107,7 +160,22 @@ def search_similar(query_img):
107
  return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
108
 
109
  @spaces.GPU
110
- def add_image(name: str, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  if not name.strip():
112
  return "Please provide a valid image name."
113
 
 
1
+ """
2
+ CLIP Image Search Application
3
+
4
+ A Gradio-based application for searching similar images using OpenAI's CLIP model.
5
+ Supports multiple image formats and provides a web interface for uploading and searching images.
6
+ """
7
+
8
  import gradio as gr
9
  from transformers import CLIPProcessor, CLIPModel
10
  from PIL import Image
 
25
  # Define supported image formats
26
  IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
27
 
28
+ def get_all_image_files() -> List[Path]:
29
+ """
30
+ Get all image files from the dataset directory.
31
+
32
+ Searches for images with supported extensions in both lowercase and uppercase.
33
+
34
+ Returns:
35
+ List[Path]: List of Path objects for all found image files
36
+ """
37
  image_files = []
38
  for ext in IMAGE_EXTENSIONS:
39
  image_files.extend(DATASET_DIR.glob(ext))
40
  image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
41
  return image_files
42
 
43
+ def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
44
+ """
45
+ Generate CLIP embedding for an image.
46
+
47
+ Args:
48
+ image (Image.Image): PIL Image object to process
49
+ device (str, optional): Device to run computation on. Defaults to "cpu".
50
+
51
+ Returns:
52
+ torch.Tensor: L2-normalized image embedding tensor
53
+
54
+ Raises:
55
+ RuntimeError: If CUDA is requested but not available
56
+ """
57
  # Use CLIP's built-in preprocessing
58
  inputs = processor(images=image, return_tensors="pt").to(device)
59
  model_device = model.to(device)
 
64
  return emb
65
 
66
  @spaces.GPU
67
+ def get_reference_embeddings() -> Dict[str, torch.Tensor]:
68
+ """
69
+ Load or compute embeddings for all reference images in the dataset.
70
+
71
+ Checks if cached embeddings are up to date with the current dataset.
72
+ If not, recomputes embeddings for all images and updates the cache.
73
+
74
+ Returns:
75
+ Dict[str, torch.Tensor]: Dictionary mapping image filenames to their embeddings
76
+
77
+ Raises:
78
+ FileNotFoundError: If dataset directory doesn't exist
79
+ PermissionError: If unable to write cache file
80
+ """
81
  # Get all current image files
82
  current_image_files = get_all_image_files()
83
  current_images = set(img_path.name for img_path in current_image_files)
 
119
  reference_embeddings = get_reference_embeddings()
120
 
121
  @spaces.GPU
122
+ def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
123
+ """
124
+ Find similar images to the query image using CLIP embeddings.
125
+
126
+ Args:
127
+ query_img (Image.Image): Query image to find similar images for
128
+
129
+ Returns:
130
+ List[Tuple[str, str]]: List of (image_path, caption) tuples, where caption is the text "Score: <similarity>" formatted to four decimal places
131
+ Limited to top 5 results above similarity threshold
132
+
133
+ Raises:
134
+ RuntimeError: If CUDA operations fail
135
+ """
136
  # Refresh embeddings to catch any new images
137
  global reference_embeddings
138
  reference_embeddings = get_reference_embeddings()
 
160
  return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
161
 
162
  @spaces.GPU
163
+ def add_image(name: str, image: Image.Image) -> str:
164
+ """
165
+ Add a new image to the dataset and update embeddings.
166
+
167
+ Args:
168
+ name (str): Name for the new image (without extension)
169
+ image (Image.Image): PIL Image object to add to dataset
170
+
171
+ Returns:
172
+ str: Success message with total image count, or an error message when name is blank
173
+
174
+ Raises:
175
+ ValueError: If name is invalid (a blank name does not raise; it returns the message "Please provide a valid image name.")
176
+ PermissionError: If unable to save image or update cache
177
+ RuntimeError: If embedding computation fails
178
+ """
179
  if not name.strip():
180
  return "Please provide a valid image name."
181