|
import faiss |
|
import torch |
|
import clip |
|
import numpy as np |
|
from PIL import Image |
|
from fastapi import FastAPI |
|
from typing import List |
|
import segmentation |
|
import os |
|
|
|
|
|
# Inference is pinned to the CPU; clip.load caches the downloaded weights
# under /tmp (its download_root).
device = "cpu"
_CLIP_ARCH = "ViT-B/32"
model, preprocess = clip.load(_CLIP_ARCH, device=device, download_root="/tmp")
|
|
|
def get_image_features(image: Image.Image) -> np.ndarray:
    """Embed ``image`` with the module-level CLIP model.

    The image goes through CLIP's ``preprocess`` transform, is wrapped in a
    batch of size one, and is encoded with gradient tracking disabled.

    Returns:
        A float32 NumPy array of shape ``(1, embedding_dim)``. The vector is
        not normalized here.
    """
    batch = preprocess(image)[None].to(device)  # leading batch axis of size 1
    with torch.no_grad():
        encoded = model.encode_image(batch)
    # `encoded` was produced under no_grad, so it carries no autograd graph and
    # can be converted straight to NumPy.
    return encoded.float().cpu().numpy()
|
|
|
|
|
# Width of the CLIP ViT-B/32 image embedding; must equal the last dimension
# of the arrays produced by get_image_features.
_EMBEDDING_DIM = 512

# Exact (brute-force) inner-product index. Vectors are L2-normalized before
# being added and before searching, so scores behave as cosine similarities.
index = faiss.IndexFlatIP(_EMBEDDING_DIM)

# meta_data_store[i] describes vector id i in `index`; both are appended
# together and must stay aligned.
meta_data_store: List[dict] = []
|
|
|
def save_image_in_index(image_features: np.ndarray, metadata: dict) -> None:
    """L2-normalize one embedding and append it to the global FAISS index.

    The vector's id in ``index`` and the position of ``metadata`` in
    ``meta_data_store`` advance together, so search hits can be mapped back
    to their metadata by id.

    Args:
        image_features: A single embedding, either 1-D ``(d,)`` or a one-row
            2-D array ``(1, d)`` such as ``get_image_features`` returns. Any
            numeric dtype is accepted. The caller's array is left unmodified.
        metadata: Information returned alongside search hits for this vector.

    Raises:
        ValueError: If ``image_features`` holds more than one vector. Only one
            metadata entry would be recorded for them, which would misalign
            index ids and ``meta_data_store``.
    """
    # faiss needs a C-contiguous float32 matrix. np.array copies by default,
    # so the in-place normalize_L2 below never touches the caller's array.
    vectors = np.array(np.atleast_2d(image_features), dtype=np.float32, order="C")
    if vectors.ndim != 2 or vectors.shape[0] != 1:
        raise ValueError(
            "expected a single embedding, got array of shape "
            f"{np.shape(image_features)}"
        )
    faiss.normalize_L2(vectors)
    index.add(vectors)
    meta_data_store.append(metadata)
|
|
|
def process_image_embedding(image_or_url, labels=None) -> "Image.Image":
    """Segment the query image and return its first detection as a PIL image.

    Despite the name, this does not compute an embedding. It runs
    ``segmentation.grounded_segmentation`` for ``labels``, crops the first
    detection with ``segmentation.cut_image``, and converts the crop into a
    uint8 PIL image that can be passed to ``get_image_features``.

    Args:
        image_or_url: Forwarded unchanged to
            ``segmentation.grounded_segmentation``.
        labels: Text labels to detect. Defaults to ``['clothes']``.

    Returns:
        The cropped region as a PIL image. Single-channel crops are expanded
        to three channels.

    Raises:
        IndexError: If segmentation returns no detections.
    """
    if labels is None:
        # A fresh list per call; a list literal default would be shared
        # (and mutable) across every call.
        labels = ['clothes']
    search_image, search_detections = segmentation.grounded_segmentation(
        image=image_or_url, labels=labels
    )
    if len(search_detections) == 0:
        raise IndexError(f"no detections found for labels {labels!r}")
    first = search_detections[0]
    cropped_image = segmentation.cut_image(search_image, first.mask, first.box)
    return Image.fromarray(_to_uint8_image_array(cropped_image))


def _to_uint8_image_array(array) -> np.ndarray:
    """Convert a crop into a uint8 array whose shape ``Image.fromarray`` accepts."""
    array = np.asarray(array)
    if array.dtype != np.uint8:
        if np.issubdtype(array.dtype, np.floating):
            # Float crops are taken to be in [0, 1]. Clip first so values just
            # outside that range saturate instead of wrapping on the cast.
            array = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
        elif array.dtype == np.bool_:
            array = array.astype(np.uint8) * np.uint8(255)
        else:
            # NOTE(review): other integer crops are assumed to already be on a
            # 0-255 scale (scaling them by 255 would overflow) — confirm
            # against segmentation.cut_image.
            array = np.clip(array, 0, 255).astype(np.uint8)
    if array.ndim == 3 and array.shape[-1] == 1:
        # Image.fromarray rejects (H, W, 1); drop the singleton channel.
        array = array[..., 0]
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    return array
|
|
|
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the indexed items most similar to the object in a query image.

    The query is cropped by ``process_image_embedding``, embedded with
    ``get_image_features``, L2-normalized and searched against the global
    inner-product ``index``. Stored vectors are normalized when added by
    ``save_image_in_index``, so scores are cosine similarities.

    Args:
        image_url: Image or URL forwarded to ``process_image_embedding``.
        k: Maximum number of results. Values ``<= 0`` yield ``[]``.

    Returns:
        Up to ``k`` dicts ``{'metadata': ..., 'score': float}``, in the order
        FAISS returns them (best match first). There are fewer than ``k``
        results when the index holds fewer vectors.
    """
    # Nothing to search, or nothing requested: skip the costly segmentation
    # and CLIP passes (faiss also rejects k <= 0).
    if k <= 0 or index.ntotal == 0:
        return []

    processed_image = process_image_embedding(image_url)
    query = get_image_features(processed_image).reshape(1, -1)
    faiss.normalize_L2(query)

    # Never ask for more neighbours than stored vectors; faiss pads any
    # surplus slots with id -1.
    distances, ids = index.search(query, min(k, index.ntotal))

    # Exclude -1 padding explicitly. A bare `i < len(...)` check lets -1
    # through, and meta_data_store[-1] would return the last entry as a
    # bogus hit.
    return [
        {'metadata': meta_data_store[i], 'score': float(dist)}
        for i, dist in zip(ids[0], distances[0])
        if 0 <= i < len(meta_data_store)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|