Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import CLIPProcessor, CLIPModel
|
3 |
from PIL import Image
|
@@ -18,15 +25,35 @@ CACHE_FILE = "cache.pkl"
|
|
18 |
# Define supported image formats
|
19 |
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
|
20 |
|
21 |
-
def get_all_image_files():
    """Collect every dataset image whose extension matches a supported format.

    Globs DATASET_DIR once per pattern in IMAGE_EXTENSIONS, trying both the
    lowercase pattern and its uppercase variant, and returns all matches.
    """
    found = []
    for pattern in IMAGE_EXTENSIONS:
        # Check the lowercase pattern first, then the UPPERCASE variant,
        # preserving the original lookup order.
        for variant in (pattern, pattern.upper()):
            found.extend(DATASET_DIR.glob(variant))
    return found
|
28 |
|
29 |
-
def get_embedding(image: Image.Image, device="cpu"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# Use CLIP's built-in preprocessing
|
31 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
32 |
model_device = model.to(device)
|
@@ -37,7 +64,20 @@ def get_embedding(image: Image.Image, device="cpu"):
|
|
37 |
return emb
|
38 |
|
39 |
@spaces.GPU
|
40 |
-
def get_reference_embeddings():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# Get all current image files
|
42 |
current_image_files = get_all_image_files()
|
43 |
current_images = set(img_path.name for img_path in current_image_files)
|
@@ -79,7 +119,20 @@ def get_reference_embeddings():
|
|
79 |
reference_embeddings = get_reference_embeddings()
|
80 |
|
81 |
@spaces.GPU
|
82 |
-
def search_similar(query_img):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
# Refresh embeddings to catch any new images
|
84 |
global reference_embeddings
|
85 |
reference_embeddings = get_reference_embeddings()
|
@@ -107,7 +160,22 @@ def search_similar(query_img):
|
|
107 |
return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
|
108 |
|
109 |
@spaces.GPU
|
110 |
-
def add_image(name: str, image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if not name.strip():
|
112 |
return "Please provide a valid image name."
|
113 |
|
|
|
1 |
+
"""
|
2 |
+
CLIP Image Search Application
|
3 |
+
|
4 |
+
A Gradio-based application for searching similar images using OpenAI's CLIP model.
|
5 |
+
Supports multiple image formats and provides a web interface for uploading and searching images.
|
6 |
+
"""
|
7 |
+
|
8 |
import gradio as gr
|
9 |
from transformers import CLIPProcessor, CLIPModel
|
10 |
from PIL import Image
|
|
|
25 |
# Define supported image formats
|
26 |
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
|
27 |
|
28 |
+
def get_all_image_files() -> List[Path]:
    """
    Get all image files from the dataset directory.

    Searches for images with supported extensions in both lowercase and
    uppercase variants. Results are de-duplicated — on case-insensitive
    filesystems (macOS, Windows) the lowercase and uppercase globs can
    match the same file twice — and sorted so the ordering is
    deterministic across runs.

    Returns:
        List[Path]: Sorted list of unique Path objects for all found image files
    """
    unique_files = set()
    for ext in IMAGE_EXTENSIONS:
        unique_files.update(DATASET_DIR.glob(ext))
        unique_files.update(DATASET_DIR.glob(ext.upper()))  # Also check uppercase
    return sorted(unique_files)
|
42 |
|
43 |
+
def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
|
44 |
+
"""
|
45 |
+
Generate CLIP embedding for an image.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
image (Image.Image): PIL Image object to process
|
49 |
+
device (str, optional): Device to run computation on. Defaults to "cpu".
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
torch.Tensor: L2-normalized image embedding tensor
|
53 |
+
|
54 |
+
Raises:
|
55 |
+
RuntimeError: If CUDA is requested but not available
|
56 |
+
"""
|
57 |
# Use CLIP's built-in preprocessing
|
58 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
59 |
model_device = model.to(device)
|
|
|
64 |
return emb
|
65 |
|
66 |
@spaces.GPU
|
67 |
+
def get_reference_embeddings() -> Dict[str, torch.Tensor]:
|
68 |
+
"""
|
69 |
+
Load or compute embeddings for all reference images in the dataset.
|
70 |
+
|
71 |
+
Checks if cached embeddings are up to date with the current dataset.
|
72 |
+
If not, recomputes embeddings for all images and updates the cache.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Dict[str, torch.Tensor]: Dictionary mapping image filenames to their embeddings
|
76 |
+
|
77 |
+
Raises:
|
78 |
+
FileNotFoundError: If dataset directory doesn't exist
|
79 |
+
PermissionError: If unable to write cache file
|
80 |
+
"""
|
81 |
# Get all current image files
|
82 |
current_image_files = get_all_image_files()
|
83 |
current_images = set(img_path.name for img_path in current_image_files)
|
|
|
119 |
reference_embeddings = get_reference_embeddings()
|
120 |
|
121 |
@spaces.GPU
|
122 |
+
def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]:
|
123 |
+
"""
|
124 |
+
Find similar images to the query image using CLIP embeddings.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
query_img (Image.Image): Query image to find similar images for
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
List[Tuple[str, str]]: List of tuples containing (image_path, similarity_score)
|
131 |
+
Limited to top 5 results above similarity threshold
|
132 |
+
|
133 |
+
Raises:
|
134 |
+
RuntimeError: If CUDA operations fail
|
135 |
+
"""
|
136 |
# Refresh embeddings to catch any new images
|
137 |
global reference_embeddings
|
138 |
reference_embeddings = get_reference_embeddings()
|
|
|
160 |
return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]
|
161 |
|
162 |
@spaces.GPU
|
163 |
+
def add_image(name: str, image: Image.Image) -> str:
|
164 |
+
"""
|
165 |
+
Add a new image to the dataset and update embeddings.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
name (str): Name for the new image (without extension)
|
169 |
+
image (Image.Image): PIL Image object to add to dataset
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: Success message with total image count
|
173 |
+
|
174 |
+
Raises:
|
175 |
+
ValueError: If name is empty or invalid
|
176 |
+
PermissionError: If unable to save image or update cache
|
177 |
+
RuntimeError: If embedding computation fails
|
178 |
+
"""
|
179 |
if not name.strip():
|
180 |
return "Please provide a valid image name."
|
181 |
|