# File size: 3,953 Bytes
# 08adf1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity

# Import your model and necessary functions
from src.config import ConfigManager
from src.token_classifier import load_token_classifier, predict
from your_model_file import YourModel  # Replace with your actual model import

# Load model and configurations
def load_model():
    """Instantiate the retrieval model and switch it to evaluation mode.

    Returns:
        The freshly constructed ``YourModel`` instance, after ``eval()`` has
        been called on it.
    """
    retrieval_model = YourModel()
    # eval() is invoked for its effect on the instance; the instance itself
    # (not the call's return value) is what gets handed back.
    retrieval_model.eval()
    return retrieval_model

def load_dataset(csv_path='database.csv'):
    """Load the retrieval database table from a CSV file.

    Args:
        csv_path: Location of the CSV file to read. Defaults to
            ``'database.csv'`` (resolved against the current working
            directory), which is the path this function used before the
            location became configurable.

    Returns:
        The parsed table as a ``pandas.DataFrame``. Other code in this file
        reads its ``target_image`` column.
    """
    return pd.read_csv(csv_path)

# (classifier, tokenizer) pairs that have already been loaded, keyed by
# (checkpoint path, device), so that a query does not reload them from disk.
_TOKEN_CLASSIFIER_CACHE = {}


def _get_token_classifier(checkpoint_path, device):
    """Return the token classifier and its tokenizer, loading them only once."""
    cache_key = (str(checkpoint_path), str(device))
    if cache_key not in _TOKEN_CLASSIFIER_CACHE:
        _TOKEN_CLASSIFIER_CACHE[cache_key] = load_token_classifier(
            checkpoint_path,
            device
        )
    return _TOKEN_CLASSIFIER_CACHE[cache_key]


def _build_object_prompts(predictions):
    """Turn tagged tokens into positive and negative "a photo of a ..." prompts.

    Consecutive tokens carrying the same object tag are merged into a single
    prompt; any other label in between starts a new prompt.

    Args:
        predictions: Iterable of ``(token, label)`` pairs.

    Returns:
        A ``(positive_prompts, negative_prompts)`` tuple of string lists.
    """
    prompts = {'<positive_object>': [], '<negative_object>': []}
    last_tag = ''
    for token, label in predictions:
        if label in prompts:
            group = prompts[label]
            if last_tag == label:
                # Same tag as the previous token: drop the trailing period
                # and append this token to the prompt being built.
                group[-1] = group[-1][:-1] + f" {token}."
            else:
                group.append(f"a photo of a {token}.")
        last_tag = label
    return prompts['<positive_object>'], prompts['<negative_object>']


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Find the database image that best matches an image + text query.

    The query image is embedded with ``model.feature_extractor.encode_image``.
    The text is run through the token classifier; every positive object it
    tags is embedded as a text prompt and added to the image embedding, and
    every negative object is subtracted. The L2-normalised result is compared
    with ``database_embeddings`` by cosine similarity.

    Args:
        model: Object exposing ``processor``, ``tokenizer`` and
            ``feature_extractor`` (with ``encode_image`` / ``encode_text``).
        query_image_path: Path (or file object) accepted by ``Image.open``.
        query_text: Text handed to ``predict`` as its ``tokens`` argument.
        database_embeddings: 2-D array-like of database embeddings; row ``i``
            is taken to correspond to row ``i`` of ``database_df``.
        database_df: Table with a ``target_image`` column.

    Returns:
        The ``target_image`` value of the most similar database row.

    Raises:
        ValueError: If ``database_df`` has no rows.
    """
    # Fail before any model work when there is nothing to search.
    if len(database_df) == 0:
        raise ValueError("database_df is empty; there is nothing to retrieve from")

    config = ConfigManager()
    device = config.get("training")["device"]

    # Process query image; the context manager closes the file handle once
    # the pixel data has been turned into a tensor.
    with Image.open(query_image_path) as raw_image:
        query_img = model.processor(raw_image).unsqueeze(0).to(device)

    # Get token classifier (cached after the first query)
    token_classifier, token_classifier_tokenizer = _get_token_classifier(
        config.get("paths")["pretrained_token_classifier_path"],
        device
    )

    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)

        # Process text query
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )

        # Process positive and negative objects
        pos, neg = _build_object_prompts(predictions)

        # Combine embeddings: positives pull the query towards an object,
        # negatives push it away. ``[0]`` takes the first row of the text
        # embedding so it broadcasts over the image embedding.
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]

        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)

    # Calculate similarities between the single query row and every database row
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']

    return most_similar_image_path

def encode_database(model, database_df):
    """Embed every database image so queries can be compared against them.

    NOTE(review): this script called ``encode_database`` without defining or
    importing it, which raised ``NameError`` as soon as the module was
    loaded. This implementation mirrors the image-embedding steps used in
    ``process_single_query``; replace it with the project's own encoder if
    one exists.

    Args:
        model: Object exposing ``processor`` and
            ``feature_extractor.encode_image``.
        database_df: Table whose ``target_image`` column holds image paths
            that ``Image.open`` can read.

    Returns:
        A 2-D NumPy array with one embedding row per row of ``database_df``,
        in the same order.

    Raises:
        ValueError: If ``database_df`` has no rows.
    """
    device = ConfigManager().get("training")["device"]
    embeddings = []
    with torch.no_grad():
        for image_path in database_df['target_image']:
            # Close each file as soon as its pixels have become a tensor.
            with Image.open(image_path) as image:
                pixels = model.processor(image).unsqueeze(0).to(device)
            embeddings.append(model.feature_extractor.encode_image(pixels).cpu())
    if not embeddings:
        raise ValueError("database_df contains no images to encode")
    return torch.cat(embeddings, dim=0).numpy()


# Initialize model and database once at import time so that every request
# served by the interface reuses the same model and embeddings.
model = load_model()
database_df = load_dataset()
database_embeddings = encode_database(model, database_df)

def interface_fn(selected_image, query_text):
    """Gradio callback: retrieve the best-matching database image.

    Args:
        selected_image: File path of the query image, or ``None`` when the
            user has not selected one.
        query_text: Free-text query entered by the user.

    Returns:
        The retrieved image as a fully loaded ``PIL.Image.Image``.

    Raises:
        gr.Error: If no query image was provided, so that the UI shows a
            readable message instead of a failure inside ``Image.open``.
    """
    if not selected_image:
        raise gr.Error("Please select a query image before searching.")

    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )

    # Image.open is lazy; load() reads the pixel data now and lets Pillow
    # release the underlying file instead of leaving it open.
    retrieved_image = Image.open(result_image_path)
    retrieved_image.load()
    return retrieved_image

# Gradio front end: a query image (passed to the callback as a file path)
# and a text box go in, the retrieved image comes out.
demo = gr.Interface(
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    fn=interface_fn,
    inputs=[
        gr.Image(label="Select Query Image", type="filepath"),
        gr.Textbox(label="Enter Query Text"),
    ],
    outputs=gr.Image(label="Retrieved Image"),
    # Each example row is [query image path, query text], matching `inputs`.
    examples=[
        ["example_images/image1.jpg", "a red car"],
        ["example_images/image2.jpg", "a blue house"],
    ],
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()