"""Gradio app: CLIP-based visual product search over a local image folder.

The "Search" tab embeds a query image with CLIP and returns the closest
reference images by cosine similarity; the "Add Product" tab saves a new
reference image into DATASET_DIR and caches its embedding in CACHE_FILE.
"""

import os
import pickle
from pathlib import Path

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load model/processor once at import time; they are shared by every request.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"
TOP_K = 5  # number of matches shown in the search gallery


def _default_device():
    """Return "cuda" when torch can see a GPU, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _save_cache(embeddings):
    """Persist *embeddings* to CACHE_FILE via a temp file and atomic rename.

    A crash mid-write therefore cannot leave a truncated pickle behind that
    would make the next start-up fail inside pickle.load.
    """
    tmp_path = f"{CACHE_FILE}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embeddings, f)
    os.replace(tmp_path, CACHE_FILE)


def get_embedding(image: Image.Image, device="cpu"):
    """Return the L2-normalised CLIP image embedding of *image* on *device*.

    The returned tensor has one row per input image (a single row here).
    """
    # Use CLIP's built-in preprocessing instead of a custom resize.
    inputs = processor(images=image, return_tensors="pt").to(device)
    # nn.Module.to() moves the shared model in place (cheap when it is
    # already on *device*), so no separate handle is needed.
    model.to(device)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    # L2-normalise the embedding.
    return emb / emb.norm(p=2, dim=-1, keepdim=True)


def get_reference_embeddings():
    """Load cached reference embeddings and sync them with DATASET_DIR.

    Returns a dict mapping image file name (e.g. "shoe.jpg") to its CPU
    embedding tensor. Any *.jpg in DATASET_DIR that is missing from the cache
    is embedded and added; cached entries whose file no longer exists are
    dropped, since the gallery could not display them. The cache file is
    rewritten only when something changed or it did not exist yet.
    """
    embeddings = {}
    cache_exists = os.path.exists(CACHE_FILE)
    if cache_exists:
        # NOTE: unpickling is only safe because this file is written by this
        # app itself; never point CACHE_FILE at untrusted data.
        with open(CACHE_FILE, "rb") as f:
            embeddings = pickle.load(f)

    on_disk = {path.name: path for path in sorted(DATASET_DIR.glob("*.jpg"))}

    stale = [name for name in embeddings if name not in on_disk]
    for name in stale:
        del embeddings[name]

    missing = [path for name, path in on_disk.items() if name not in embeddings]
    if missing:
        # Use the GPU for reference images too when available, for consistency.
        device = _default_device()
        for img_path in missing:
            # Context manager closes the underlying file handle.
            with Image.open(img_path) as img:
                rgb = img.convert("RGB")
            # Store on CPU to save GPU memory.
            embeddings[img_path.name] = get_embedding(rgb, device=device).cpu()

    if stale or missing or not cache_exists:
        _save_cache(embeddings)
    return embeddings


reference_embeddings = get_reference_embeddings()


@spaces.GPU
def search_similar(query_img):
    """Return the TOP_K reference images most similar to *query_img*.

    Output is a list of (image_path, caption) pairs for gr.Gallery, best match
    first. Returns an empty list when no query image was provided or when the
    reference set is empty.
    """
    if query_img is None:
        return []
    # Snapshot the shared dict so a concurrent add_image() cannot change its
    # size while we read it.
    items = list(reference_embeddings.items())
    if not items:
        return []

    device = _default_device()
    # Same RGB conversion as applied to the reference images.
    query_emb = get_embedding(query_img.convert("RGB"), device=device)

    names = [name for name, _ in items]
    # Stack all references into one matrix: a single host->device copy and a
    # single broadcasted cosine-similarity call instead of one per image.
    ref_matrix = torch.cat([emb for _, emb in items], dim=0).to(device)
    scores = torch.nn.functional.cosine_similarity(query_emb, ref_matrix, dim=1).tolist()

    results = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in results[:TOP_K]]


@spaces.GPU
def add_image(name: str, image):
    """Save *image* as DATASET_DIR/<name>.jpg and index its embedding.

    Returns a status message for the UI. *name* must be a plain file stem:
    empty names and names containing path components are rejected so user
    input cannot write outside DATASET_DIR.
    """
    name = (name or "").strip()
    if not name or Path(name).name != name or name in (".", ".."):
        return "Please enter a simple image name (no slashes or path components)."
    if image is None:
        return "Please upload an image."

    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    # JPEG cannot store alpha/palette modes (e.g. PNG uploads); convert first so
    # the saved file and the indexed embedding come from the same pixels.
    rgb_image = image.convert("RGB")
    rgb_image.save(DATASET_DIR / f"{name}.jpg")

    # Use GPU for consistency if available.
    emb = get_embedding(rgb_image, device=_default_device())
    # Store on CPU to save memory.
    reference_embeddings[f"{name}.jpg"] = emb.cpu()
    # Pickle a snapshot so a concurrent mutation cannot break the dump.
    _save_cache(dict(reference_embeddings))
    return f"Image {name} added to dataset."


search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
demo.launch()