import faiss
import numpy as np
import torch
import clip
from PIL import Image
from fastapi import FastAPI
from typing import List
import segmentation
import os

device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, download_root="/tmp")


def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()


# FAISS setup: inner-product index over 512-dim CLIP ViT-B/32 embeddings.
# With L2-normalized vectors, inner product is equivalent to cosine similarity.
index = faiss.IndexFlatIP(512)
meta_data_store = []


def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features and add them to the index alongside their metadata."""
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)


def process_image_embedding(image_or_url, labels=['clothes']) -> Image.Image:
    """Segment the query image and return the cropped region as a PIL image."""
    search_image, search_detections = segmentation.grounded_segmentation(
        image=image_or_url, labels=labels
    )
    cropped_image = segmentation.cut_image(
        search_image, search_detections[0].mask, search_detections[0].box
    )
    # Convert to a valid uint8 RGB array before building the PIL image.
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)
    return Image.fromarray(cropped_image)


def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar images in the index."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)
    results = []
    for i, dist in zip(indices[0], distances[0]):
        # FAISS pads missing results with -1, so skip invalid indices.
        if 0 <= i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
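

# Minimal usage sketch, not part of the original module: it assumes a local
# "catalog/" directory of JPEG product images (a hypothetical path) and that
# segmentation.grounded_segmentation accepts local file paths as well as URLs.
if __name__ == "__main__":
    import glob

    # Index a small catalog: segment, crop, embed, then store with metadata.
    for path in glob.glob("catalog/*.jpg"):
        cropped = process_image_embedding(path)
        features = get_image_features(cropped)
        save_image_in_index(features, {"path": path})

    # Query with one of the images and print the closest matches by cosine similarity.
    for hit in get_top_k_results("catalog/query.jpg", k=5):
        print(hit["score"], hit["metadata"]["path"])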