import faiss
import torch
import clip
import numpy as np
from PIL import Image
from typing import List

import segmentation

device = "cpu"  # switch to "cuda" if a GPU is available
model, preprocess = clip.load("ViT-B/32", device=device)

def get_image_features(image: Image.Image) -> np.ndarray:
    """Extract CLIP features from an image."""
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    return image_features.cpu().numpy()

# FAISS setup
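# Inner-product search over L2-normalized vectors is cosine similarity;
# 512 matches the embedding width of CLIP ViT-B/32.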
index = faiss.IndexFlatIP(512)
meta_data_store = []

def save_image_in_index(image_features: np.ndarray, metadata: dict):
    """Normalize features in place and add them to the index."""
    faiss.normalize_L2(image_features)
    index.add(image_features)
    meta_data_store.append(metadata)

def process_image_embedding(image_url: str, labels: List[str] = ["clothes"]) -> Image.Image:
    """Segment the query image, crop the first detection, and return it as an RGB PIL image."""
    search_image, search_detections = segmentation.grounded_segmentation(image=image_url, labels=labels)
    cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
    # Convert to a valid uint8 RGB array before building the PIL image
    if cropped_image.dtype != np.uint8:
        cropped_image = (cropped_image * 255).astype(np.uint8)
    if cropped_image.ndim == 2:
        cropped_image = np.stack([cropped_image] * 3, axis=-1)
    pil_image = Image.fromarray(cropped_image)
    return pil_image
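
# Ingestion sketch (not part of the original file): a hypothetical helper,
# index_image_from_url, tying segmentation, embedding, and storage together.
# It assumes `metadata` is any dict the caller wants returned with search hits.
def index_image_from_url(image_url: str, metadata: dict) -> None:
    pil_image = process_image_embedding(image_url)  # segment + crop
    features = get_image_features(pil_image)        # (1, 512) float32
    save_image_in_index(features, metadata)         # normalize + store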

def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
    """Find the top-k most similar indexed images for a query image."""
    processed_image = process_image_embedding(image_url)
    image_search_embedding = get_image_features(processed_image)
    faiss.normalize_L2(image_search_embedding)
    distances, indices = index.search(image_search_embedding, k)
    results = []
    for i, dist in zip(indices[0], distances[0]):
        # FAISS pads results with -1 when fewer than k vectors are indexed;
        # a bare `i < len(...)` check would let -1 wrap to the last item.
        if 0 <= i < len(meta_data_store):
            results.append({
                'metadata': meta_data_store[i],
                'score': float(dist)
            })
    return results
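
# Usage sketch (assumption: the file names below are placeholders, and at
# least one image must be indexed before a search returns anything).
if __name__ == "__main__":
    index_image_from_url("catalog_item.jpg", {"sku": "A123", "name": "red dress"})
    for result in get_top_k_results("query_photo.jpg", k=5):
        print(result["score"], result["metadata"])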