# getitem/src/model.py
import faiss
import torch
import clip
import numpy as np
from PIL import Image
from typing import List
import segmentation
device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, download_root="/tmp")
def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()
# FAISS setup: inner-product index over 512-d CLIP embeddings.
# With L2-normalized vectors, inner product equals cosine similarity.
index = faiss.IndexFlatIP(512)
meta_data_store = []
def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features and add them to the index."""
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)
def process_image_embedding(image_or_url, labels=['clothes']) -> Image.Image:
    """Segment the query image and return the cropped RGB region for the given labels.

    The returned PIL image is meant to be passed to get_image_features.
    """
    search_image, search_detections = segmentation.grounded_segmentation(image=image_or_url, labels=labels)
    cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
    # Convert to a valid uint8 RGB array before building a PIL image.
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)
    return Image.fromarray(cropped_image)
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar images in the index."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding.reshape(1, -1), k)
    results = []
    for i, dist in zip(indices[0], distances[0]):
        # FAISS pads results with -1 when fewer than k vectors are indexed; skip those.
        if 0 <= i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
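

# --- Usage sketch (hypothetical) ---
# A minimal end-to-end example of the ingest-then-search flow above, assuming
# segmentation.grounded_segmentation accepts a URL as well as a PIL image.
# The `catalog` list and its URLs/SKUs are illustrative only, not part of the module.
if __name__ == "__main__":
    catalog = [
        {"url": "https://example.com/dress.jpg", "sku": "SKU-001"},
        {"url": "https://example.com/jacket.jpg", "sku": "SKU-002"},
    ]
    # Ingest: segment and crop each catalog image, embed it, and index it.
    for item in catalog:
        cropped = process_image_embedding(item["url"])
        features = get_image_features(cropped)
        save_image_in_index(features, item)
    # Query: search the index with a new image URL and print the hits.
    for hit in get_top_k_results("https://example.com/query.jpg", k=5):
        print(hit["score"], hit["metadata"]["sku"])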