import os
import pickle
from pathlib import Path

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model and its processor once at startup.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

DATASET_DIR = Path("dataset")
DATASET_DIR.mkdir(exist_ok=True)  # avoid FileNotFoundError on first run
CACHE_FILE = "cache.pkl"


def preprocess_image(image: Image.Image) -> Image.Image:
    return image.resize((224, 224)).convert("RGB")


def get_embedding(image: Image.Image, device: str = "cpu") -> torch.Tensor:
    """Return the L2-normalized CLIP embedding of a single image."""
    image = preprocess_image(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    model.to(device)  # .to() moves the module in place and returns it
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    # Normalize so cosine similarity reduces to a plain dot product.
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb


def get_reference_embeddings():
    """Embed every dataset image, reusing the pickle cache when present."""
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    for img_path in DATASET_DIR.glob("*.jpg"):
        img = Image.open(img_path).convert("RGB")
        embeddings[img_path.name] = get_embedding(img)

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings


reference_embeddings = get_reference_embeddings()


@spaces.GPU
def search_similar(query_img: Image.Image):
    """Rank the reference images by cosine similarity to the query."""
    query_emb = get_embedding(query_img, device="cuda")
    results = []
    for name, ref_emb in reference_embeddings.items():
        sim = torch.nn.functional.cosine_similarity(
            query_emb, ref_emb.to("cuda")
        ).item()
        results.append((name, sim))
    results.sort(key=lambda x: x[1], reverse=True)
    # gr.Gallery accepts (image, caption) pairs.
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]


def add_image(name: str, image: Image.Image) -> str:
    """Save a new product image and update the embedding cache."""
    path = DATASET_DIR / f"{name}.jpg"
    image.convert("RGB").save(path)  # JPEG cannot store an alpha channel
    reference_embeddings[f"{name}.jpg"] = get_embedding(image)
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {name} added to dataset."


search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    # .style(grid=5) was removed in Gradio 4; grid layout is now a constructor arg.
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface(
    [search_interface, add_interface], tab_names=["Search", "Add Product"]
)

demo.launch()
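
# Optional sketch (not part of the app above): because every embedding is
# already L2-normalized, cosine similarity is just a dot product, so the
# per-image Python loop in search_similar can be replaced by one matrix
# multiply once the references are stacked. The names ref_names / ref_matrix
# below are illustrative, not from the original code.
#
#     ref_names = list(reference_embeddings.keys())
#     ref_matrix = torch.cat([reference_embeddings[n] for n in ref_names], dim=0)  # (N, D)
#     sims = (query_emb @ ref_matrix.T.to(query_emb.device)).squeeze(0)            # (N,)
#     top = sims.topk(min(5, len(ref_names)))
#     results = [(ref_names[i], sims[i].item()) for i in top.indices.tolist()]
#
# For a handful of products the loop is fine; the matmul form matters once the
# dataset grows to thousands of images.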