File size: 2,088 Bytes
e8ee9c0
 
 
 
 
 
 
 
ea99a27
 
e8ee9c0
 
58597bf
e8ee9c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6f293b
e8ee9c0
f6f293b
e8ee9c0
 
 
 
 
 
 
 
f6f293b
e8ee9c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6f293b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import faiss
import torch
import clip
import numpy as np
from PIL import Image
from fastapi import FastAPI
from typing import List
import segmentation
import os


# All inference in this module runs on the CPU; the same device string is
# passed to clip.load below and used when moving input tensors in
# get_image_features.
device = "cpu"
# Load the CLIP "ViT-B/32" model and its matching preprocessing transform once,
# at import time. download_root="/tmp" is the directory clip.load uses for the
# downloaded model files.
model, preprocess = clip.load("ViT-B/32", device=device, download_root="/tmp")

def get_image_features(image: Image.Image) -> np.ndarray:
    """Encode a PIL image with CLIP and return the embedding as a NumPy array.

    The image goes through the module-level ``preprocess`` transform, gets a
    leading batch dimension, and is encoded by ``model`` with gradient
    tracking disabled. The result is cast to float before conversion.
    """
    batch = preprocess(image).unsqueeze(0)
    batch = batch.to(device)
    with torch.no_grad():
        embedding = model.encode_image(batch)
        embedding = embedding.float()
    on_cpu = embedding.cpu()
    return on_cpu.numpy()

# FAISS setup
# Flat (exhaustive) inner-product index over 512-dimensional vectors. The
# functions in this file L2-normalize vectors before adding or searching, so
# the inner-product score acts as cosine similarity.
index = faiss.IndexFlatIP(512)
# Metadata dicts appended alongside the vectors added to `index`; lookups in
# get_top_k_results use the FAISS result id as a position in this list.
meta_data_store = []

def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize feature vector(s) and add them to the index with metadata.

    Args:
        image_features: A single embedding of shape ``(d,)`` or a batch of
            shape ``(n, d)``. ``d`` must match the dimensionality of the
            module-level ``index``.
        metadata: Record stored in ``meta_data_store`` for every vector
            added by this call.

    The caller's array is not modified: normalization is applied to a
    private float32 copy.
    """
    # faiss.normalize_L2 works in place and, like index.add, expects a
    # C-contiguous float32 matrix. Copying satisfies that requirement and
    # avoids silently rewriting the array the caller passed in.
    vectors = np.array(image_features, dtype=np.float32, order="C")
    if vectors.ndim == 1:
        vectors = vectors.reshape(1, -1)
    faiss.normalize_L2(vectors)
    index.add(vectors)
    # One metadata entry per added row keeps list positions aligned with the
    # sequential ids FAISS assigns, even when a batch is added at once.
    meta_data_store.extend([metadata] * vectors.shape[0])

def process_image_embedding(image_or_url, labels=None) -> Image.Image:
    """Segment the query image and return the first detection as a PIL image.

    Despite its name, this function does not compute an embedding; it only
    prepares the cropped image that is later passed to the feature extractor.

    Args:
        image_or_url: Image input forwarded unchanged to
            ``segmentation.grounded_segmentation``.
        labels: Labels to segment. Defaults to ``['clothes']`` when omitted.

    Returns:
        The crop of the first detection, converted to a ``uint8`` array with
        grayscale input expanded to three channels, wrapped in a PIL image.

    Raises:
        IndexError: If segmentation yields no detections for ``labels``.
    """
    # A fresh list per call avoids sharing one mutable default across calls.
    if labels is None:
        labels = ['clothes']

    search_image, search_detections = segmentation.grounded_segmentation(image=image_or_url, labels=labels)
    try:
        first_detection = search_detections[0]
    except IndexError as err:
        # Same exception type as before, but with an actionable message.
        raise IndexError(f"no detections found for labels {labels!r}") from err
    cropped_image = segmentation.cut_image(search_image, first_detection.mask, first_detection.box)

    # Convert to valid RGB
    # NOTE(review): scaling by 255 assumes non-uint8 output is in [0, 1] —
    # confirm against segmentation.cut_image.
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        # Single-channel crop: replicate it into three identical channels.
        cropped_image = np.stack([cropped_image] * 3, axis=-1)

    return Image.fromarray(cropped_image)

def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar indexed images for a query image.

    Args:
        image_url: Query image reference, passed to
            ``process_image_embedding`` for segmentation and cropping.
        k: Maximum number of neighbours to request from the index.

    Returns:
        A list of ``{'metadata': ..., 'score': ...}`` dicts in the order
        returned by the index search. It may contain fewer than ``k``
        entries (or none) when the index holds fewer than ``k`` vectors.
        ``score`` is the inner product of the L2-normalized query with the
        stored vector.
    """
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)

    # FAISS pads missing neighbours with id -1. Without the lower bound, -1
    # would wrap around to the last metadata entry (or raise IndexError on an
    # empty store) and be reported as a real hit.
    return [
        {
            'metadata': meta_data_store[i],
            'score': float(dist),
        }
        for i, dist in zip(indices[0], distances[0])
        if 0 <= i < len(meta_data_store)
    ]