# src/model.py
import os
from typing import List

import faiss
import numpy as np
import torch
import clip
from PIL import Image

import segmentation

device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, download_root=os.getcwd())
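# ViT-B/32 produces 512-dimensional image embeddings, matching the FAISS index
# dimensionality below; download_root caches the model weights in the current
# working directory.
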
def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image as a (1, 512) float32 array."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()
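
# Example (hypothetical file name):
#   features = get_image_features(Image.open("shirt.jpg").convert("RGB"))
#   features.shape  # -> (1, 512)
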
# FAISS setup: an exact inner-product index over 512-d CLIP embeddings.
# With L2-normalized vectors, inner product equals cosine similarity.
index = faiss.IndexFlatIP(512)
meta_data_store: List[dict] = []

def save_image_in_index(image_features: np.ndarray, metadata: dict) -> None:
    """L2-normalize the features in place and add them to the index."""
    faiss.normalize_L2(image_features)  # mutates the array in place
    index.add(image_features)
    meta_data_store.append(metadata)
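
# Each vector's position in the index lines up with its metadata's position in
# meta_data_store; get_top_k_results relies on this parallel-array invariant.
# Example call (hypothetical metadata):
#   save_image_in_index(features, {"path": "catalog/shirt.jpg"})
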
def process_image_embedding(image_or_url, labels=None) -> Image.Image:
    """Segment the query image, crop the first detection, and return it as a PIL image."""
    labels = labels if labels is not None else ["clothes"]
    search_image, search_detections = segmentation.grounded_segmentation(image=image_or_url, labels=labels)
    cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
    # Convert the crop to a valid uint8 RGB array before handing it to PIL
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)
    return Image.fromarray(cropped_image)
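
# Note: this assumes grounded_segmentation returns at least one detection for
# the given labels; with no match, search_detections[0] above raises IndexError.
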
def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar indexed images for a query image."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)  # already (1, 512)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding, k)
    results = []
    for i, dist in zip(indices[0], distances[0]):
        # FAISS pads results with -1 when the index holds fewer than k vectors,
        # so guard against negative indices as well as out-of-range ones.
        if 0 <= i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
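
# --- Usage sketch (illustrative, not part of the module's public API) ---
# A minimal end-to-end example, assuming segmentation.grounded_segmentation
# accepts file paths or URLs; all paths and metadata below are placeholders.
if __name__ == "__main__":
    for path in ["catalog/shirt_01.jpg", "catalog/dress_02.jpg"]:  # hypothetical files
        crop = process_image_embedding(path)   # segment and crop the garment
        features = get_image_features(crop)    # (1, 512) CLIP embedding
        save_image_in_index(features, {"path": path})
    # Scores are cosine similarities in [-1, 1]; higher means more similar.
    for hit in get_top_k_results("query/outfit.jpg", k=2):
        print(hit["score"], hit["metadata"]["path"])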