"""Gradio demo for compositional image retrieval.

Given a query image and a text instruction describing objects to add or remove,
the app retrieves the closest matching image from a pre-encoded image database.
"""

import os
import zipfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from dataset import RetrievalDataset
from model import Model
from token_classifier import load_token_classifier, predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512


def unzip_file(zip_path, extract_path):
    # Create the target directory if it doesn't exist
    os.makedirs(extract_path, exist_ok=True)
    # Extract all contents of the archive into the target directory
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)


# Unpack the bundled sample evaluation data
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
unzip_file(zip_path, extract_path)

# Download the pretrained retrieval weights
hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')


def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate embeddings.

    Args:
        model: Retrieval model whose feature extractor encodes the images.
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array (num_images, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), batch_size)):
        target_imgs = torch.stack([
            model.processor(Image.open(target_image_path))
            for target_image_path in df['target_image'][i:i + batch_size]
        ]).to(device)
        with torch.no_grad():
            # target_imgs_embedding = model.encode_database_image(target_imgs)
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)


# Load model and configurations
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained=None)
    model.load("weights.pth")
    model.eval()
    return model


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    # Process query image
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)

    # Get token classifier
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "safinal/compositional-image-retrieval-token-classifier",
        device
    )

    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)

        # Process text query
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )

        # Split predicted tokens into positive (add) and negative (remove) object phrases.
        # NOTE: the original label strings were lost when this file was extracted;
        # '<positive>' and '<negative>' below are assumed placeholders for the
        # token classifier's actual tag names.
        pos = []
        neg = []
        last_tag = ''
        for token, label in predictions:
            if label == '<positive>':
                if last_tag != '<positive>':
                    pos.append(f"a photo of a {token}.")
                else:
                    pos[-1] = pos[-1][:-1] + f" {token}."
            elif label == '<negative>':
                if last_tag != '<negative>':
                    neg.append(f"a photo of a {token}.")
                else:
                    neg[-1] = neg[-1][:-1] + f" {token}."
            last_tag = label

        # Combine embeddings: add positive-object text embeddings, subtract negative ones
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Calculate similarities against the pre-computed database embeddings
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
    return most_similar_image_path


# Initialize model and database
model = load_model()
test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)
# Pre-compute embeddings for every image in the retrieval database
database_embeddings = encode_database(model, test_dataset.load_database())


def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        test_dataset.load_database()
    )
    return Image.open(result_image_path)


# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
        gr.Textbox(label="Enter Query Text", lines=2)
    ],
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
    ],
    allow_flagging="never",
    cache_examples=False
)

if __name__ == "__main__":
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        print(f"Error launching app: {str(e)}")
        raise
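# A minimal sketch of running one retrieval without the Gradio UI, mirroring
# interface_fn above (the image path and query text are taken from the bundled
# sample_evaluation examples):
#
#     result_path = process_single_query(
#         model,
#         "sample_evaluation/images/261684.png",
#         "Bring cow into the picture, and then follow up with removing bench.",
#         database_embeddings,
#         test_dataset.load_database(),
#     )
#     print(result_path)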