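"""Gradio demo for compositional image retrieval.

Given a query image and a text instruction such as "add bowl and remove shoe",
the app composes an image-plus-text embedding and retrieves the closest image
from a sample database.
"""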
import os
import zipfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512
def unzip_file(zip_path, extract_path):
    # Create the target directory if it doesn't exist
    os.makedirs(extract_path, exist_ok=True)
    # Extract all contents of the archive into the target directory
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)


# Unpack the bundled sample evaluation set (images + annotations)
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
unzip_file(zip_path, extract_path)
# Fetch the trained retrieval model weights from the Hugging Face Hub
hf_hub_download(repo_id="safinal/compositional-retrieval", filename="weights.pth", local_dir='.')
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate embeddings.

    Args:
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array (num_images, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), batch_size)):
        # Load and preprocess one batch of database images
        target_imgs = torch.stack([
            model.processor(Image.open(target_image_path))
            for target_image_path in df['target_image'][i:i + batch_size]
        ]).to(device)
        with torch.no_grad():
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            # L2-normalize so cosine similarity later reduces to a dot product
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
            all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)
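
# Embedding runs in batches of `batch_size`, and each batch is moved to the
# CPU as soon as it is encoded, so GPU memory stays bounded regardless of
# database size.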
# Load the model and restore the downloaded weights
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained=None)
    model.load("weights.pth")
    model.eval()
    return model
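
# Note: the model is presumably constructed with pretrained=None because its
# full state is restored from the downloaded weights.pth, so no backbone
# weights need to be fetched separately.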
def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    # Preprocess the query image
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)

    # Load the token classifier that tags object mentions in the text query
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "trained_distil_bert_base",
        device
    )

    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)

        # Tag each token of the text query as a positive (add) or
        # negative (remove) object mention
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )

        # Group consecutive tokens with the same tag into object phrases
        pos = []
        neg = []
        last_tag = ''
        for token, label in predictions:
            if label == '<positive_object>':
                if last_tag != '<positive_object>':
                    pos.append(f"a photo of a {token}.")
                else:
                    pos[-1] = pos[-1][:-1] + f" {token}."
            elif label == '<negative_object>':
                if last_tag != '<negative_object>':
                    neg.append(f"a photo of a {token}.")
                else:
                    neg[-1] = neg[-1][:-1] + f" {token}."
            last_tag = label
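
        # Example: for the query "add bowl and remove shoe", the classifier is
        # expected to yield pos == ["a photo of a bowl."] and
        # neg == ["a photo of a shoe."]; consecutive tokens sharing a tag
        # (e.g. "fire hydrant") are merged into a single phrase.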
        # Compose the query embedding: add positive object embeddings,
        # subtract negative ones
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Rank database images by cosine similarity and return the closest match
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
    return most_similar_image_path
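
# Retrieval is a brute-force argmax over cosine similarities against every
# database embedding; since all vectors are L2-normalized this is equivalent
# to a maximum inner product search, and it is fast enough for a sample-sized
# database.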
# Initialize model and database
model = load_model()
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)
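
# The database is embedded once at startup; each query then costs one image
# forward pass plus one text forward pass per detected object phrase.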
def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    return Image.open(result_image_path)
# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
    ]
)

if __name__ == "__main__":
    demo.launch()