Spaces:
Runtime error
Runtime error
File size: 4,946 Bytes
08adf1a a263f63 08adf1a ffd2453 a263f63 ffd2453 08adf1a b0bd1f4 08adf1a ffd2453 08adf1a ffd2453 08adf1a ffd2453 08adf1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512
def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """Encode every database image into an L2-normalized embedding.

    Args:
        model: Model wrapper exposing ``processor`` (image transform) and
            ``feature_extractor.encode_image``.
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images.

    Returns:
        np.ndarray: Embeddings array of shape (num_images, embedding_dim).
        Empty (0, 0) array when ``df`` has no rows.
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), batch_size)):
        batch_imgs = []
        for target_image_path in df['target_image'][i:i + batch_size]:
            # Context manager closes the image file handle promptly instead
            # of leaking it until garbage collection.
            with Image.open(target_image_path) as img:
                batch_imgs.append(model.processor(img))
        target_imgs = torch.stack(batch_imgs).to(device)
        with torch.no_grad():
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            # L2-normalize so cosine similarity reduces to a dot product.
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    if not all_embeddings:
        # Original code crashed on an empty DataFrame (np.concatenate([])).
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(all_embeddings)
# Load model and configurations
def load_model():
    """Construct the retrieval model and put it in inference mode.

    Returns:
        Model: ViTamin-L-384 backbone pretrained on datacomp1b,
        switched to eval mode.
    """
    retrieval_model = Model(model_name="ViTamin-L-384", pretrained="datacomp1b")
    # Fine-tuned checkpoint loading is intentionally disabled; re-enable to
    # restore trained weights:
    # retrieval_model.load("weights.pth")
    retrieval_model.eval()
    return retrieval_model
# Lazily-created token classifier shared across requests: loading the
# DistilBERT checkpoint is expensive and the model is stateless, so there
# is no reason to reload it on every query.
_token_classifier = None
_token_classifier_tokenizer = None


def _get_token_classifier():
    """Load the token classifier once, then reuse it for every query."""
    global _token_classifier, _token_classifier_tokenizer
    if _token_classifier is None:
        _token_classifier, _token_classifier_tokenizer = load_token_classifier(
            "trained_distil_bert_base",
            device
        )
    return _token_classifier, _token_classifier_tokenizer


def parse_query_objects(predictions):
    """Group token-level tags into positive / negative object phrases.

    Consecutive tokens carrying the same tag are merged into one prompt,
    e.g. [("red", "<positive_object>"), ("car", "<positive_object>")]
    becomes "a photo of a red car.".

    Args:
        predictions: Iterable of (token, label) pairs as returned by
            ``predict``.

    Returns:
        tuple[list[str], list[str]]: (positive prompts, negative prompts).
    """
    pos = []
    neg = []
    last_tag = ''
    for token, label in predictions:
        if label == '<positive_object>':
            if last_tag != '<positive_object>':
                pos.append(f"a photo of a {token}.")
            else:
                # Extend the current phrase: drop trailing '.', append token.
                pos[-1] = pos[-1][:-1] + f" {token}."
        elif label == '<negative_object>':
            if last_tag != '<negative_object>':
                neg.append(f"a photo of a {token}.")
            else:
                neg[-1] = neg[-1][:-1] + f" {token}."
        last_tag = label
    return pos, neg


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Retrieve the database image most similar to a composed (image + text) query.

    The query image embedding is shifted by adding text embeddings of
    positive object mentions and subtracting negative ones, then matched
    against the precomputed database embeddings by cosine similarity.

    Args:
        model: Model wrapper with ``processor``, ``tokenizer`` and a
            ``feature_extractor`` exposing ``encode_image``/``encode_text``.
        query_image_path (str): Path to the query image.
        query_text (str): Free-text modification query.
        database_embeddings (np.ndarray): (num_images, dim) database matrix.
        database_df (pd.DataFrame): DataFrame with a 'target_image' column
            aligned row-for-row with ``database_embeddings``.

    Returns:
        str: Path of the most similar database image.
    """
    # Context manager closes the query image file handle promptly.
    with Image.open(query_image_path) as img:
        query_img = model.processor(img).unsqueeze(0).to(device)
    token_classifier, token_classifier_tokenizer = _get_token_classifier()
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        # Tag each query token as a positive / negative object mention.
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        pos, neg = parse_query_objects(predictions)
        # Compose the query: add positive-object text embeddings, subtract
        # negative ones, then re-normalize onto the unit sphere.
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
    # Rank database entries by cosine similarity to the composed query.
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    most_similar_idx = np.argmax(similarities)
    return database_df.iloc[most_similar_idx]['target_image']
# Initialize model and database (runs once at import/startup).
model = load_model()
# Test split of the retrieval benchmark; the model's own preprocessing
# transform and tokenizer are passed through to the dataset.
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)
# Precompute embeddings for every database image up front so each Gradio
# request only needs to encode its query.
database_embeddings = encode_database(model, test_dataset.load_database())
# Cache the database DataFrame: load_database() re-reads the annotations
# from disk, and its result never changes between requests.
_database_df = None


def interface_fn(selected_image, query_text):
    """Gradio callback: map (query image path, query text) to the best match.

    Args:
        selected_image (str): Filesystem path of the query image
            (gr.Image is configured with type="filepath").
        query_text (str): Free-text modification query.

    Returns:
        PIL.Image.Image: The retrieved database image.
    """
    global _database_df
    if _database_df is None:
        _database_df = test_dataset.load_database()
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        _database_df
    )
    return Image.open(result_image_path)
# Create Gradio interface: one image + one text input, one image output.
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        # type="filepath" hands interface_fn a path string, not an array.
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["example_images/image1.jpg", "a red car"],
        ["example_images/image2.jpg", "a blue house"]
    ]
)
# Launch only when run as a script (hosting platforms may import this module).
if __name__ == "__main__":
    demo.launch()