# Scraped from a Hugging Face Space page (Space status at capture time: "Runtime error").
# File size: 3,953 bytes; commit 08adf1a.
import functools
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

# Import your model and necessary functions
from src.config import ConfigManager
from src.token_classifier import load_token_classifier, predict
from your_model_file import YourModel  # Replace with your actual model import
# Load model and configurations
def load_model():
    """Create the retrieval model and switch it to evaluation mode.

    Returns:
        The ``YourModel`` instance after ``.eval()`` has been called on it.
    """
    net = YourModel()  # Initialize your model
    # Presumably a torch nn.Module: eval() turns off training-only behaviour.
    net.eval()
    # NOTE(review): process_single_query moves its inputs to the configured
    # device, but nothing here moves the model there; confirm YourModel does.
    return net
def load_dataset(csv_path='database.csv'):
    """Load the retrieval database table from a CSV file.

    Args:
        csv_path: Location of the CSV. Defaults to ``database.csv`` relative to
            the working directory (the previously hard-coded path).

    Returns:
        A pandas DataFrame with one row per database entry.

    Raises:
        ValueError: If the CSV has no ``target_image`` column. Retrieval returns
            values from that column, so failing here gives a clear message
            instead of a KeyError on the first query.
    """
    database_df = pd.read_csv(csv_path)
    if 'target_image' not in database_df.columns:
        raise ValueError(
            f"{csv_path!s} must contain a 'target_image' column; "
            f"found columns: {list(database_df.columns)}"
        )
    return database_df
@functools.lru_cache(maxsize=4)
def _get_token_classifier(checkpoint_path, device):
    """Load the token classifier once per (checkpoint, device) pair and reuse it.

    Reloading the pretrained checkpoint for every query repeated the same
    expensive load on each request. ``device`` must be hashable (e.g. a str
    or ``torch.device``) to serve as a cache key.
    """
    return load_token_classifier(checkpoint_path, device)


def _group_object_phrases(predictions):
    """Turn tagged tokens into text prompts for positive and negative objects.

    Consecutive tokens with the same ``<positive_object>`` or
    ``<negative_object>`` label are merged into one prompt of the form
    ``"a photo of a <tokens>."``. Tokens with any other label are skipped,
    but they still break a run.

    Returns:
        ``(pos, neg)``: lists of prompt strings.
    """
    pos = []
    neg = []
    buckets = {'<positive_object>': pos, '<negative_object>': neg}
    last_tag = ''
    for token, label in predictions:
        bucket = buckets.get(label)
        if bucket is not None:
            if last_tag != label:
                bucket.append(f"a photo of a {token}.")
            else:
                # Same tag as the previous token: extend the open phrase,
                # moving the trailing period to the end.
                bucket[-1] = bucket[-1][:-1] + f" {token}."
        last_tag = label
    return pos, neg


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Return the database image that best matches an image + text query.

    Steps:
        1. Embed the query image with ``model.feature_extractor.encode_image``.
        2. Tag words in ``query_text`` with the token classifier.
        3. Add (positive objects) or subtract (negative objects) the text
           embedding of each resulting prompt.
        4. L2-normalise the composed embedding and rank it against
           ``database_embeddings`` by cosine similarity.

    Args:
        model: Retrieval model exposing ``processor``, ``tokenizer`` and
            ``feature_extractor``.
        query_image_path: Path to the query image file.
        query_text: Free-text query passed to the token classifier.
        database_embeddings: 2-D array of database embeddings. Row ``i`` is
            assumed to correspond to ``database_df.iloc[i]`` (TODO confirm
            against the producer).
        database_df: Table whose ``target_image`` column holds the value returned.

    Returns:
        The ``target_image`` value of the most similar database row.
    """
    config = ConfigManager()
    device = config.get("training")["device"]
    token_classifier, token_classifier_tokenizer = _get_token_classifier(
        config.get("paths")["pretrained_token_classifier_path"],
        device,
    )

    # Close the file handle promptly. Convert to RGB so grayscale / RGBA
    # uploads reach the (presumably 3-channel) preprocessing in a valid mode.
    with Image.open(query_image_path) as raw_img:
        query_img = model.processor(raw_img.convert("RGB")).unsqueeze(0).to(device)

    # All model calls stay inside no_grad. Otherwise text embeddings would carry
    # grad, the in-place updates below would make query_img_embd require grad,
    # and Tensor.numpy() would refuse it.
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        pos, neg = _group_object_phrases(predictions)
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    most_similar_idx = int(np.argmax(similarities))
    return database_df.iloc[most_similar_idx]['target_image']
def encode_database(model, database_df, batch_size=32):
    """Embed every database image so that queries can be ranked against them.

    This function was called at start-up but never defined or imported, so the
    module crashed with NameError. Each path in ``database_df['target_image']``
    is opened and converted to RGB. It is then preprocessed with
    ``model.processor``, encoded with ``model.feature_extractor.encode_image``
    and L2-normalised, mirroring how ``process_single_query`` encodes the query
    image.

    Args:
        model: Retrieval model exposing ``processor`` and ``feature_extractor``.
        database_df: Table with a ``target_image`` column of image file paths.
        batch_size: Number of images encoded per forward pass.

    Returns:
        numpy array with one row per row of ``database_df``, in the same order.

    Raises:
        ValueError: If ``database_df`` has no rows.
    """
    if len(database_df) == 0:
        raise ValueError("database_df is empty; there is nothing to retrieve from")
    device = ConfigManager().get("training")["device"]
    image_paths = database_df['target_image'].tolist()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(image_paths), batch_size):
            batch = []
            for path in image_paths[start:start + batch_size]:
                with Image.open(path) as img:
                    # process_single_query unsqueezes processor output, so it
                    # is assumed to be a per-image tensor that torch.stack
                    # can batch (TODO confirm).
                    batch.append(model.processor(img.convert("RGB")))
            embeddings = model.feature_extractor.encode_image(torch.stack(batch).to(device))
            embeddings = torch.nn.functional.normalize(embeddings, dim=1, p=2)
            chunks.append(embeddings.cpu().numpy())
    return np.concatenate(chunks, axis=0)


# Initialize model and database
model = load_model()
database_df = load_dataset()
database_embeddings = encode_database(model, database_df)
def interface_fn(selected_image, query_text):
    """Gradio callback: retrieve the database image that best matches the query.

    Args:
        selected_image: Filepath of the query image (the input component uses
            ``type="filepath"``), or ``None`` when no image was provided.
        query_text: Text from the query textbox, forwarded unchanged.

    Returns:
        The retrieved image as a fully loaded PIL image.

    Raises:
        gr.Error: If no query image was provided. Gradio shows this message
            to the user instead of a stack trace.
    """
    if selected_image is None:
        raise gr.Error("Please select a query image.")
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    # copy() forces the pixel data to load, so the file can be closed here
    # rather than leaking an open handle per request.
    with Image.open(result_image_path) as img:
        return img.copy()
# Create Gradio interface
# Candidate example rows: (query image path, query text).
_EXAMPLES = [
    ["example_images/image1.jpg", "a red car"],
    ["example_images/image2.jpg", "a blue house"],
]
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    # Offer only examples whose image file exists, so a missing optional asset
    # cannot break start-up (Gradio may fail on missing example files).
    # None means "no examples".
    examples=[row for row in _EXAMPLES if Path(row[0]).is_file()] or None,
)
# Launch the web UI only when run as a script, not on import.
# The stray " |" left by the page scrape after launch() was a SyntaxError.
if __name__ == "__main__":
    demo.launch()