# NOTE: removed non-Python residue from a web/blame view (file size,
# commit hashes, line-number gutter) that would break the interpreter.
import faiss
import torch
import clip
import numpy as np
from PIL import Image
from fastapi import FastAPI
from typing import List
import segmentation
import os
# Run inference on CPU only; no GPU is assumed in this deployment.
device = "cpu"
# Load CLIP ViT-B/32 and its matching image preprocessor once at import time.
# Weights are cached under /tmp so repeated startups skip the download.
model, preprocess = clip.load("ViT-B/32", device=device, download_root="/tmp")
def get_image_features(image: Image.Image) -> np.ndarray:
    """Encode a PIL image into a CLIP feature vector.

    Returns:
        A float32 numpy array on the CPU (shape (1, 512) for ViT-B/32,
        matching the FAISS index dimensionality used elsewhere in this file).
    """
    # Apply CLIP's preprocessing and add a leading batch dimension.
    batch = preprocess(image).unsqueeze(0).to(device)
    # Pure inference — disable autograd bookkeeping.
    with torch.no_grad():
        features = model.encode_image(batch).float()
    return features.cpu().numpy()
# FAISS setup
# Flat (exact, brute-force) inner-product index over 512-d CLIP embeddings.
# With L2-normalized vectors, inner product equals cosine similarity.
index = faiss.IndexFlatIP(512)
# Metadata is coupled positionally: meta_data_store[i] describes FAISS row i.
meta_data_store = []
def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """L2-normalize a copy of the features and add them to the FAISS index.

    Args:
        image_features: Feature array of shape (n, d) matching the index
            dimensionality (512 here).
        metadata: Record stored alongside the vectors; retrieved by row
            position at search time.
    """
    # Work on a float32, contiguous COPY: faiss.normalize_L2 mutates its
    # argument in place, and the original code silently modified the
    # caller's array. FAISS also requires contiguous float32 input.
    features = np.ascontiguousarray(image_features, dtype=np.float32).copy()
    faiss.normalize_L2(features)
    index.add(features)
    # Positional coupling: metadata index i corresponds to FAISS row i.
    meta_data_store.append(metadata)
def process_image_embedding(image_or_url, labels=None) -> Image.Image:
    """Segment the target object out of an image and return it as a PIL image.

    Note: despite the name, this returns a cropped image, not an embedding;
    callers pass the result to get_image_features() for the actual encoding.

    Args:
        image_or_url: Image (or URL) accepted by segmentation.grounded_segmentation.
        labels: Text labels to ground the segmentation on. Defaults to
            ['clothes'] (avoids a mutable default argument).

    Returns:
        An RGB (uint8) PIL.Image of the first detection, cut to its mask/box.

    Raises:
        ValueError: If segmentation finds no objects for the given labels
            (the original code raised a bare IndexError here).
    """
    if labels is None:
        labels = ['clothes']
    search_image, detections = segmentation.grounded_segmentation(
        image=image_or_url, labels=labels
    )
    if not detections:
        raise ValueError(f"No objects detected for labels {labels!r}")
    # Only the first detection is used — presumably the highest-confidence
    # one; TODO confirm grounded_segmentation's ordering.
    cropped = segmentation.cut_image(
        search_image, detections[0].mask, detections[0].box
    )
    # Normalize to uint8 so PIL can construct the image.
    if cropped.dtype != np.uint8:
        # assumes float pixel values in [0, 1] — TODO confirm cut_image's range
        cropped = (cropped * 255).astype(np.uint8)
    if cropped.ndim == 2:
        # Grayscale result: replicate the single channel into RGB.
        cropped = np.stack([cropped] * 3, axis=-1)
    return Image.fromarray(cropped)
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Return up to k most similar indexed images for a query image.

    Args:
        image_url: Query image (or URL); routed through the segmentation +
            CLIP pipeline before searching.
        k: Number of neighbors requested.

    Returns:
        A list of {'metadata': dict, 'score': float} entries, ordered by
        descending inner-product similarity (cosine, since both the query
        and the stored vectors are L2-normalized).
    """
    processed_image = process_image_embedding(image_url)
    query = get_image_features(processed_image)
    # Normalize so inner product equals cosine similarity (IndexFlatIP).
    faiss.normalize_L2(query)
    distances, indices = index.search(query.reshape(1, -1), k)
    results = []
    for idx, score in zip(indices[0], distances[0]):
        # FAISS pads `indices` with -1 when the index holds fewer than k
        # vectors. The original `idx < len(...)` guard let -1 through and
        # negative indexing silently returned the LAST stored metadata.
        if 0 <= idx < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[idx],
                'score': float(score),
            })
    return results
|