|
import gradio as gr |
|
from transformers import CLIPProcessor, CLIPModel |
|
from PIL import Image |
|
import torch |
|
import pickle |
|
from pathlib import Path |
|
import os |
|
import spaces |
|
|
|
|
|
# Load the CLIP model and its matching preprocessor once at import time so
# every request reuses the same weights.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

model.eval()  # inference only: disables dropout / batch-norm updates


# Directory holding the searchable reference product images.
DATASET_DIR = Path("dataset")

# On-disk pickle of precomputed reference embeddings (skips re-encoding on restart).
CACHE_FILE = "cache.pkl"
|
|
|
def preprocess_image(image: Image.Image) -> Image.Image:
    """Normalize an input image for CLIP: force 3-channel RGB at 224x224.

    Converting to RGB *before* resizing matters: Pillow resizes palette
    ("P") images with nearest-neighbour resampling, and RGBA alpha can
    bleed into the result, so resampling in RGB space gives cleaner
    pixels. For images that are already RGB the output is unchanged.
    """
    return image.convert("RGB").resize((224, 224))
|
|
|
def get_embedding(image: Image.Image, device="cpu"):
    """Encode a PIL image into a unit-length CLIP feature vector.

    The returned tensor lives on `device`. NOTE: `nn.Module.to` moves the
    shared module in place, so the global `model` ends up on whichever
    device was requested most recently.
    """
    prepared = preprocess_image(image)
    batch = processor(images=prepared, return_tensors="pt").to(device)
    clip = model.to(device)
    with torch.no_grad():
        features = clip.get_image_features(**batch)
        # L2-normalize so cosine similarity reduces to a dot product.
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features
|
|
|
def get_reference_embeddings():
    """Load the reference-image embedding index, computing it on first run.

    Returns:
        dict mapping file name (e.g. "shoe.jpg") to its normalized CLIP
        embedding tensor.

    The index is cached in CACHE_FILE and a stale cache is returned as-is;
    delete the file to force a rebuild after the dataset changes on disk.
    """
    if os.path.exists(CACHE_FILE):
        # NOTE: pickle.load is only acceptable because the cache is produced
        # locally by this app — never load a cache file from untrusted sources.
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    # sorted() makes the build order deterministic across filesystems.
    for img_path in sorted(DATASET_DIR.glob("*.jpg")):
        # Context manager closes the file handle promptly; a bare
        # Image.open() would leak descriptors across a large dataset.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        embeddings[img_path.name] = get_embedding(rgb)
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings
|
|
|
# Build (or load from cache) the in-memory embedding index at startup;
# add_image() mutates this dict when new products are registered.
reference_embeddings = get_reference_embeddings()
|
|
|
@spaces.GPU
def search_similar(query_img):
    """Rank the dataset against a query image and return the top 5 matches.

    Returns a list of (image_path, "Score: ...") pairs suitable for a
    Gradio Gallery. Runs on CUDA via the @spaces.GPU decorator.
    """
    query = get_embedding(query_img, device="cuda")
    scored = [
        (name, torch.nn.functional.cosine_similarity(query, emb.to("cuda")).item())
        for name, emb in reference_embeddings.items()
    ]
    best = sorted(scored, key=lambda pair: pair[1], reverse=True)[:5]
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in best]
|
|
|
def add_image(name: str, image):
    """Save a new product image into the dataset and index it immediately.

    Args:
        name: user-supplied base name for the image (untrusted text input).
        image: PIL image from the Gradio upload widget.

    Returns:
        A confirmation message string.

    Hardening over the naive version:
    - the name is reduced to its basename so input like "../x" cannot
      write outside DATASET_DIR;
    - the image is converted to RGB first, because saving an RGBA or
      palette PIL image as JPEG raises OSError;
    - the dataset directory is created if missing.
    """
    safe_name = Path(name).name  # strip any directory components
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    path = DATASET_DIR / f"{safe_name}.jpg"
    rgb = image.convert("RGB")
    rgb.save(path)
    emb = get_embedding(rgb)
    reference_embeddings[f"{safe_name}.jpg"] = emb
    # Persist the updated index so the addition survives a restart.
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {safe_name} added to dataset."
|
|
|
# --- Gradio UI -----------------------------------------------------------
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    # Gallery layout moved into constructor kwargs: the chained
    # `.style(grid=...)` call was deprecated in Gradio 3.x and removed
    # entirely in Gradio 4.0, where it raises AttributeError.
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface(
    [search_interface, add_interface],
    tab_names=["Search", "Add Product"],
)

# Guard the launch so importing this module (e.g. from tests or a Spaces
# runner that discovers `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|