import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512

# Label strings emitted by the token classifier. The exact values were lost in
# the source (the original tags appear to have been stripped as HTML-like
# markup); "<positive>"/"<negative>" are assumed placeholders and must match
# the label set the classifier was trained with.
POSITIVE_LABEL = "<positive>"
NEGATIVE_LABEL = "<negative>"


def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate embeddings.

    Args:
        model: Model wrapper exposing `processor` and `feature_extractor`.
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array (num_images, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), batch_size)):
        target_imgs = torch.stack([
            model.processor(Image.open(target_image_path))
            for target_image_path in df['target_image'][i:i + batch_size]
        ]).to(device)
        with torch.no_grad():
            # target_imgs_embedding = model.encode_database_image(target_imgs)
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)


# Load model and configurations
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained="datacomp1b")
    # model.load("weights.pth")
    model.eval()
    return model


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    # Process query image
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)

    # Get token classifier (loaded on every call; could be cached at module level)
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "trained_distil_bert_base",
        device
    )

    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)

    # Process text query
    predictions = predict(
        tokens=query_text,
        model=token_classifier,
        tokenizer=token_classifier_tokenizer,
        device=device,
        max_length=128
    )

    # Split tokens into positive and negative object phrases. Consecutive
    # tokens with the same label are merged into a single phrase.
    pos = []
    neg = []
    last_tag = ''
    for token, label in predictions:
        if label == POSITIVE_LABEL:
            if last_tag != POSITIVE_LABEL:
                pos.append(f"a photo of a {token}.")
            else:
                pos[-1] = pos[-1][:-1] + f" {token}."
        elif label == NEGATIVE_LABEL:
            if last_tag != NEGATIVE_LABEL:
                neg.append(f"a photo of a {token}.")
            else:
                neg[-1] = neg[-1][:-1] + f" {token}."
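        # Hypothetical illustration (labels are the assumed placeholders above):
        # for the query "a red car without a tree", predictions such as
        # [("red", POSITIVE_LABEL), ("car", POSITIVE_LABEL), ("tree", NEGATIVE_LABEL)]
        # would yield pos == ["a photo of a red car."] and neg == ["a photo of a tree."].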
        last_tag = label

    # Combine embeddings: add positive phrase embeddings to the image
    # embedding and subtract negative ones, then re-normalize.
    for obj in pos:
        query_img_embd += model.feature_extractor.encode_text(
            model.tokenizer(obj).to(device)
        )[0]
    for obj in neg:
        query_img_embd -= model.feature_extractor.encode_text(
            model.tokenizer(obj).to(device)
        )[0]
    query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Calculate similarities
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']

    return most_similar_image_path


# Initialize model and database
model = load_model()
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)
# Load the database DataFrame once and reuse it for encoding and per-query lookup
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)


def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    return Image.open(result_image_path)


# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["example_images/image1.jpg", "a red car"],
        ["example_images/image2.jpg", "a blue house"]
    ]
)

if __name__ == "__main__":
    demo.launch()
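# Optional sketch (not part of the original pipeline): cache the database
# embeddings on disk so they are not recomputed on every launch. The cache
# path "db_embeddings.npy" is an arbitrary, assumed name.
#
#   import os
#   EMB_CACHE = "db_embeddings.npy"
#   if os.path.exists(EMB_CACHE):
#       database_embeddings = np.load(EMB_CACHE)
#   else:
#       database_embeddings = encode_database(model, database_df)
#       np.save(EMB_CACHE, database_embeddings)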