import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
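
# `spaces` is the Hugging Face Spaces helper library: functions decorated with
# @spaces.GPU are allocated a GPU (ZeroGPU) for the duration of each call.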

# Load the CLIP model and its preprocessing pipeline once at startup
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"
# Make sure the dataset directory exists before reading from or saving into it
DATASET_DIR.mkdir(exist_ok=True)

def get_embedding(image: Image.Image, device="cpu"):
    # Use CLIP's built-in preprocessing instead of a custom resize
    inputs = processor(images=image, return_tensors="pt").to(device)
    model.to(device)  # nn.Module.to() moves the model in place
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    # L2-normalize so cosine similarity reduces to a dot product
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb
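
# For the ViT-B/32 checkpoint above, get_embedding returns a (1, 512) tensor.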

def get_reference_embeddings():
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)
    embeddings = {}
    # Use GPU for preprocessing reference images too, for consistency
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for img_path in DATASET_DIR.glob("*.jpg"):
        img = Image.open(img_path).convert("RGB")
        emb = get_embedding(img, device=device)
        # Store on CPU to save GPU memory
        embeddings[img_path.name] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings

reference_embeddings = get_reference_embeddings()
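# The reference index is built once at startup; add_image() below updates it
# in memory and re-persists the pickle cache.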

@spaces.GPU
def search_similar(query_img):
    # @spaces.GPU guarantees a CUDA device is available for this call
    query_emb = get_embedding(query_img, device="cuda")
    results = []
    for name, ref_emb in reference_embeddings.items():
        # Move the reference embedding to the same device as the query
        ref_emb_gpu = ref_emb.to("cuda")
        # Compute cosine similarity between the two embeddings
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb_gpu, dim=1).item()
        results.append((name, sim))
    results.sort(key=lambda x: x[1], reverse=True)
    # gr.Gallery accepts (image, caption) pairs; return the five best matches
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]
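
# A sketch (not wired into the UI below), assuming the same cache layout: for a
# larger dataset, stacking the cached embeddings into one tensor computes every
# similarity in a single matrix multiply instead of a Python loop. Because
# get_embedding() L2-normalizes, the dot product equals cosine similarity.
@spaces.GPU
def search_similar_batched(query_img, top_k=5):
    query_emb = get_embedding(query_img, device="cuda")  # (1, D)
    names = list(reference_embeddings.keys())
    # Stack the (1, D) cached tensors into an (N, D) matrix on the GPU
    refs = torch.cat([reference_embeddings[n] for n in names]).to("cuda")
    sims = (query_emb @ refs.T).squeeze(0)  # (N,) cosine similarities
    top = torch.topk(sims, k=min(top_k, len(names)))
    return [(f"dataset/{names[i]}", f"Score: {v.item():.4f}")
            for i, v in zip(top.indices.tolist(), top.values)]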

@spaces.GPU
def add_image(name: str, image):
    path = DATASET_DIR / f"{name}.jpg"
    image.save(path)
    # Use GPU for consistency if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)
    # Store on CPU to save memory
    reference_embeddings[f"{name}.jpg"] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {name} added to dataset."

search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)
add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
demo.launch()
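
# Assuming this file is saved as app.py, it can be run locally with
# `python app.py`; the @spaces.GPU decorator is a no-op outside a ZeroGPU Space.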