import gradio as gr
import torch
import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity


from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset
from generate_embeds import encode_database
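# Repo-local modules: the token classifier tags query tokens as positive or
# negative objects; Model wraps the vision-language backbone used for retrieval.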


# Load the retrieval model with its fine-tuned weights
# (assumes weights.pth is available in the working directory)
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained=None)
    model.load("weights.pth")
    model.eval()
    return model


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Compose the query image and text into a single embedding and return
    the path of the most similar image in the database."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Preprocess the query image into a batched tensor on the right device
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)

    # Token classifier that tags positive/negative objects in the query text.
    # Loaded per call for simplicity; caching it at module level would avoid
    # reloading on every query.
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "trained_distil_bert_base",
        device
    )
    
    with torch.no_grad():
        # Image embedding of the query
        query_img_embd = model.feature_extractor.encode_image(query_img)

        # Tag each token of the text query
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
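        # `predict` is expected to return (token, label) pairs; labels mark
        # tokens as '<positive_object>', '<negative_object>', or other.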
        
        # Group consecutive tokens that share a label into natural-language
        # phrases: positives to add to the query, negatives to subtract
        pos = []
        neg = []
        last_tag = ''
        for token, label in predictions:
            if label == '<positive_object>':
                if last_tag != '<positive_object>':
                    pos.append(f"a photo of a {token}.")
                else:
                    # Continue the current phrase with the next token
                    pos[-1] = pos[-1][:-1] + f" {token}."
            elif label == '<negative_object>':
                if last_tag != '<negative_object>':
                    neg.append(f"a photo of a {token}.")
                else:
                    neg[-1] = neg[-1][:-1] + f" {token}."
            last_tag = label
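        # e.g. (assumed tagging) "a dog without a leash" would yield
        #   pos == ["a photo of a dog."], neg == ["a photo of a leash."]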
            
        # Compositional arithmetic in embedding space: add text embeddings
        # of wanted objects, subtract those of unwanted ones
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]

        # Project the composed embedding back onto the unit sphere
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
        
    # Rank the database by cosine similarity to the composed query embedding
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]

    # Return the path of the best-matching image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']

    return most_similar_image_path

# Initialize model and database
model = load_model()

test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)

# Load the database once and reuse it for both embedding and path lookup
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)

def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    return Image.open(result_image_path)

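# Example direct call, bypassing Gradio (hypothetical image path, for
# illustration only):
#   best_path = process_single_query(model, "query.jpg",
#                                    "a cat without a collar",
#                                    database_embeddings, database_df)
#   Image.open(best_path).show()
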
# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["example_images/image1.jpg", "a red car"],
        ["example_images/image2.jpg", "a blue house"]
    ]
)

if __name__ == "__main__":
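    # Pass share=True to demo.launch() to get a temporary public URL,
    # e.g. when running on a remote machine.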
    demo.launch()