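"""CLIP image-similarity search demo (Gradio app for Hugging Face Spaces).

Reference images live in dataset/; their CLIP embeddings are cached in
cache.pkl and compared to queries by cosine similarity.
"""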
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces

# Load model/processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

DATASET_DIR = Path("dataset")
DATASET_DIR.mkdir(exist_ok=True)  # make sure the folder exists before saving/globbing
CACHE_FILE = "cache.pkl"

def get_embedding(image: Image.Image, device="cpu"):
    """Return the L2-normalized CLIP embedding of a single image."""
    # Use CLIP's built-in preprocessing instead of a custom resize
    inputs = processor(images=image, return_tensors="pt").to(device)
    model.to(device)  # nn.Module.to() moves the model in place
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    # L2-normalize so cosine similarity reduces to a dot product
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb
    
def get_reference_embeddings():
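    """Load cached reference embeddings from cache.pkl, or compute and cache them.

    Note: the cache is not invalidated when dataset/ changes on disk;
    delete cache.pkl to force a rebuild.
    """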
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    # Use GPU for preprocessing reference images too for consistency
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    for img_path in DATASET_DIR.glob("*.jpg"):
        img = Image.open(img_path).convert("RGB")
        emb = get_embedding(img, device=device)
        # Store on CPU to save GPU memory
        embeddings[img_path.name] = emb.cpu()
    
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings

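# Built once at startup so searches don't re-embed the whole dataset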
reference_embeddings = get_reference_embeddings()

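# @spaces.GPU requests a GPU for the duration of the call (Hugging Face ZeroGPU)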
@spaces.GPU
def search_similar(query_img):
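    """Return the top-5 dataset matches for the query image, with cosine scores."""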
    # Fall back to CPU when no GPU is available (e.g. when running locally)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)
    results = []

    for name, ref_emb in reference_embeddings.items():
        # Move the reference embedding to the same device as the query
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb.to(device), dim=1).item()
        results.append((name, sim))
    
    results.sort(key=lambda x: x[1], reverse=True)
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in results[:5]]

@spaces.GPU
def add_image(name: str, image):
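    """Save the uploaded image to dataset/ and add its embedding to the index."""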
    path = DATASET_DIR / f"{name}.jpg"
    # JPEG cannot store an alpha channel, so normalize to RGB before saving
    image = image.convert("RGB")
    image.save(path)
    
    # Use GPU for consistency if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)
    
    # Store on CPU to save memory
    reference_embeddings[f"{name}.jpg"] = emb.cpu()
    
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image {name} added to dataset."

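# Two-tab UI: query by image, or register a new reference image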
search_interface = gr.Interface(fn=search_similar,
                                inputs=gr.Image(type="pil", label="Query Image"),
                                outputs=gr.Gallery(label="Top Matches", columns=5),
                                allow_flagging="never")

add_interface = gr.Interface(fn=add_image,
                             inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
                             outputs="text",
                             allow_flagging="never")

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
demo.launch()