ImgSearch / app.py
AkinyemiAra's picture
initial commit
e6e631c verified
raw
history blame
2.71 kB
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
# Filesystem locations: the folder of reference images and the pickled
# embedding cache that sits next to the app.
DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"

# The processor and the model are loaded from the same CLIP checkpoint.
# eval() puts the module in evaluation mode and returns the module itself,
# so it can be chained directly onto the load.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
def preprocess_image(image: Image.Image) -> Image.Image:
    """Return a 224x224 RGB copy of *image*.

    The image is stretched to exactly 224x224; the aspect ratio is not
    preserved.

    The conversion to RGB happens *before* the resize. Pillow forces
    nearest-neighbour resampling when resizing mode "1" and "P" images, so
    resizing first would degrade palette and bilevel inputs. Inputs that are
    already RGB produce the same result as before.
    """
    return image.convert("RGB").resize((224, 224))
def get_embedding(image: Image.Image, device="cpu"):
    """Compute the L2-normalised CLIP image embedding of *image*.

    Args:
        image: PIL image to embed; it is run through ``preprocess_image``
            before being handed to the CLIP processor.
        device: Torch device on which the forward pass is executed.

    Returns:
        The tensor returned by ``get_image_features`` divided by its L2 norm
        along the last dimension. The tensor stays on ``device``.
    """
    batch = processor(images=preprocess_image(image), return_tensors="pt").to(device)
    # The module-level model is moved to the requested device on every call.
    clip = model.to(device)
    with torch.no_grad():
        features = clip.get_image_features(**batch)
        norms = features.norm(p=2, dim=-1, keepdim=True)
        unit_features = features / norms
    return unit_features
def get_reference_embeddings():
    """Load the reference embeddings, computing and caching them if needed.

    If ``CACHE_FILE`` exists its unpickled content is returned as-is.
    Otherwise every ``*.jpg`` directly inside ``DATASET_DIR`` is embedded with
    ``get_embedding`` (default device), the result is pickled to
    ``CACHE_FILE`` and returned.

    Returns:
        A dict mapping image file name (e.g. ``"shoe.jpg"``) to its embedding.
    """
    if os.path.exists(CACHE_FILE):
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. This is only safe while cache.pkl is written exclusively by
        # this application; never point CACHE_FILE at untrusted data.
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    for img_path in DATASET_DIR.glob("*.jpg"):
        # Image.open is lazy and keeps the file handle open. convert() reads
        # the pixel data into a new image, after which the context manager
        # closes the handle instead of leaking one per dataset image.
        with Image.open(img_path) as img:
            rgb_img = img.convert("RGB")
        embeddings[img_path.name] = get_embedding(rgb_img)

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings
# Module-level store of reference embeddings keyed by image file name.
# Populated once when this module is executed (from the cache file if it
# exists); search_similar reads it and add_image inserts into it.
reference_embeddings = get_reference_embeddings()
@spaces.GPU
def search_similar(query_img):
    """Return the five reference images most similar to *query_img*.

    Args:
        query_img: PIL image supplied by the UI, or ``None`` when the user
            submitted the form without an image.

    Returns:
        A list of up to five ``(image_path, caption)`` tuples ordered by
        decreasing cosine similarity, where ``caption`` is
        ``"Score: <similarity>"``. The list is empty when there is no query
        image or no reference embeddings.
    """
    # Nothing to rank: avoid crashing on a missing query, and avoid
    # torch.cat on an empty list.
    if query_img is None or not reference_embeddings:
        return []

    query_emb = get_embedding(query_img, device="cuda")

    # Stack all reference embeddings so a single device copy and a single
    # similarity call replace one copy plus one .item() sync per image.
    # Each stored embedding is expected to be a (1, dim) tensor as returned
    # by get_embedding, so concatenation yields (num_images, dim).
    names = list(reference_embeddings)
    ref_matrix = torch.cat([reference_embeddings[n] for n in names]).to("cuda")
    scores = torch.nn.functional.cosine_similarity(query_emb, ref_matrix).tolist()

    ranked = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
    # Build paths from DATASET_DIR so they stay in sync with where add_image
    # and get_reference_embeddings read and write the files.
    return [
        (str(DATASET_DIR / name), f"Score: {score:.4f}")
        for name, score in ranked[:5]
    ]
def add_image(name: str, image):
    """Save *image* into the dataset under *name* and index its embedding.

    Args:
        name: User-supplied base name (without extension) for the new image.
        image: PIL image supplied by the UI, or ``None`` if none was uploaded.

    Returns:
        A human-readable status message describing the outcome.
    """
    if image is None:
        return "No image provided."

    # ``name`` comes straight from a web form. Keep only the final path
    # component so values such as "../../x" cannot write outside DATASET_DIR.
    safe_name = Path((name or "").strip()).name
    if not safe_name or safe_name in {".", ".."}:
        return "Please provide a valid image name."

    # The dataset folder may not exist yet on a fresh deployment.
    DATASET_DIR.mkdir(parents=True, exist_ok=True)

    filename = f"{safe_name}.jpg"
    # JPEG cannot store alpha or palette modes; normalise to RGB before saving.
    rgb_image = image.convert("RGB")
    rgb_image.save(DATASET_DIR / filename)

    reference_embeddings[filename] = get_embedding(rgb_image)
    # Persist the whole index so the new image survives a restart.
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {safe_name} added to dataset."
# "Search" tab: takes a query image and shows the best matches in a gallery.
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    # The Gallery.style(grid=...) method was removed in Gradio 4; the number
    # of columns is now passed to the constructor.
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

# "Add Product" tab: takes a name and an image and reports the outcome as text.
add_interface = gr.Interface(
    fn=add_image,
    inputs=[
        gr.Text(label="Image Name"),
        gr.Image(type="pil", label="Product Image"),
    ],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface(
    [search_interface, add_interface],
    tab_names=["Search", "Add Product"],
)
demo.launch()