""" CLIP Image Search Application A Gradio-based application for searching similar images using OpenAI's CLIP model. Supports multiple image formats and provides a web interface for uploading and searching images. """ import gradio as gr from transformers import CLIPProcessor, CLIPModel from PIL import Image import torch import pickle from pathlib import Path import os import spaces from typing import List, Dict, Tuple, Optional, Union # Load model/processor model: CLIPModel = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") processor: CLIPProcessor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") model.eval() DATASET_DIR: Path = Path("dataset") CACHE_FILE: str = "cache.pkl" # Define supported image formats IMAGE_EXTENSIONS: List[str] = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"] def get_all_image_files() -> List[Path]: """ Get all image files from the dataset directory. Searches for images with supported extensions in both lowercase and uppercase. Returns: List[Path]: List of Path objects for all found image files """ image_files: List[Path] = [] for ext in IMAGE_EXTENSIONS: image_files.extend(DATASET_DIR.glob(ext)) image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase return image_files def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor: """ Generate CLIP embedding for an image. Args: image (Image.Image): PIL Image object to process device (str, optional): Device to run computation on. Defaults to "cpu". Returns: torch.Tensor: L2-normalized image embedding tensor Raises: RuntimeError: If CUDA is requested but not available """ # Use CLIP's built-in preprocessing inputs = processor(images=image, return_tensors="pt").to(device) model_device = model.to(device) with torch.no_grad(): emb: torch.Tensor = model_device.get_image_features(**inputs) # L2 normalize the embeddings emb = emb / emb.norm(p=2, dim=-1, keepdim=True) return emb @spaces.GPU def get_reference_embeddings() -> Dict[str, torch.Tensor]: """ Load or compute embeddings for all reference images in the dataset. Checks if cached embeddings are up to date with the current dataset. If not, recomputes embeddings for all images and updates the cache. Returns: Dict[str, torch.Tensor]: Dictionary mapping image filenames to their embeddings Raises: FileNotFoundError: If dataset directory doesn't exist PermissionError: If unable to write cache file """ # Get all current image files current_image_files: List[Path] = get_all_image_files() current_images: set = set(img_path.name for img_path in current_image_files) # Load existing cache if it exists cached_embeddings: Dict[str, torch.Tensor] = {} if os.path.exists(CACHE_FILE): with open(CACHE_FILE, "rb") as f: cached_embeddings = pickle.load(f) # Check if cache is up to date cached_images: set = set(cached_embeddings.keys()) # If cache is missing images or has extra images, rebuild if current_images != cached_images: print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}") embeddings: Dict[str, torch.Tensor] = {} device: str = "cuda" if torch.cuda.is_available() else "cpu" for img_path in current_image_files: print(f"Processing {img_path.name}...") try: img: Image.Image = Image.open(img_path).convert("RGB") emb: torch.Tensor = get_embedding(img, device=device) embeddings[img_path.name] = emb.cpu() except Exception as e: print(f"Error processing {img_path.name}: {e}") continue # Save updated cache with open(CACHE_FILE, "wb") as f: pickle.dump(embeddings, f) print(f"Cache updated with {len(embeddings)} images") return embeddings else: print(f"Using cached embeddings for {len(cached_embeddings)} images") return cached_embeddings # Initialize reference embeddings reference_embeddings: Dict[str, torch.Tensor] = get_reference_embeddings() @spaces.GPU def search_similar(query_img: Image.Image) -> List[Tuple[str, str]]: """ Find similar images to the query image using CLIP embeddings. Args: query_img (Image.Image): Query image to find similar images for Returns: List[Tuple[str, str]]: List of tuples containing (image_path, similarity_score) Limited to top 5 results above similarity threshold Raises: RuntimeError: If CUDA operations fail """ # Refresh embeddings to catch any new images global reference_embeddings reference_embeddings = get_reference_embeddings() query_emb: torch.Tensor = get_embedding(query_img, device="cuda") results: List[Tuple[str, float]] = [] for name, ref_emb in reference_embeddings.items(): # Move reference embedding to same device as query ref_emb_gpu: torch.Tensor = ref_emb.to("cuda") # Compute cosine similarity sim: float = torch.nn.functional.cosine_similarity(query_emb, ref_emb_gpu, dim=1).item() results.append((name, sim)) results.sort(key=lambda x: x[1], reverse=True) # Filter out low similarity results (adjust threshold as needed) SIMILARITY_THRESHOLD: float = 0.2 # Only show results above 20% similarity filtered_results: List[Tuple[str, float]] = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD] if not filtered_results: return [("No similar images found", "No matches above similarity threshold")] # Return top 5 results return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]] @spaces.GPU def add_image(name: str, image: Image.Image) -> str: """ Add a new image to the dataset and update embeddings. Args: name (str): Name for the new image (without extension) image (Image.Image): PIL Image object to add to dataset Returns: str: Success message with total image count Raises: ValueError: If name is empty or invalid PermissionError: If unable to save image or update cache RuntimeError: If embedding computation fails """ if not name.strip(): return "Please provide a valid image name." # Save as PNG to preserve quality for all input formats path: Path = DATASET_DIR / f"{name}.png" image.save(path, "PNG") # Use GPU for consistency if available device: str = "cuda" if torch.cuda.is_available() else "cpu" emb: torch.Tensor = get_embedding(image, device=device) # Add to current embeddings and save cache reference_embeddings[f"{name}.png"] = emb.cpu() with open(CACHE_FILE, "wb") as f: pickle.dump(reference_embeddings, f) return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}" # Create Gradio interfaces search_interface: gr.Interface = gr.Interface( fn=search_similar, inputs=gr.Image(type="pil", label="Query Image"), outputs=gr.Gallery(label="Top Matches", columns=5), allow_flagging="never", title="Image Similarity Search", description="Upload an image to find similar images in the dataset" ) add_interface: gr.Interface = gr.Interface( fn=add_image, inputs=[ gr.Text(label="Image Name", placeholder="Enter a unique name for your image"), gr.Image(type="pil", label="Product Image") ], outputs="text", allow_flagging="never", title="Add Image to Dataset", description="Add a new image to the searchable dataset" ) # Create main application demo: gr.TabbedInterface = gr.TabbedInterface( [search_interface, add_interface], tab_names=["Search", "Add Product"], title="CLIP Image Search System", theme=gr.themes.Soft() ) if __name__ == "__main__": # Ensure dataset directory exists DATASET_DIR.mkdir(exist_ok=True) demo.launch(share=True, mcp_server=True)