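"""Gradio demo for compositional image retrieval.

Embeds a query image, adds text embeddings of positive objects and subtracts
those of negative objects (as tagged by a token classifier), then retrieves
the nearest database image by cosine similarity.
"""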
import gradio as gr
import torch
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm


from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512  # number of database images embedded per forward pass


def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate embeddings.

    Args:
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array of shape (num_images, embedding_dim).
    """
    model.eval()
    all_embeddings = []
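    # Embed the database in batches to keep GPU memory bounded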
    for i in tqdm(range(0, len(df), batch_size)):
        batch_paths = df['target_image'][i:i + batch_size]
        target_imgs = torch.stack(
            [model.processor(Image.open(path)) for path in batch_paths]
        ).to(device)
        with torch.no_grad():
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
        # L2-normalize so retrieval reduces to cosine similarity between unit vectors
        target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)


# Load model and configurations
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained="datacomp1b")
    # model.load("weights.pth")  # uncomment to load fine-tuned weights, if available
    model.eval()
    return model


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
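    """Embed an (image, text) query and return the path of the most similar database image."""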
    
    # Process query image
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
    
    # Load the token classifier (reloaded on every call; consider caching at module scope)
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "trained_distil_bert_base",
        device
    )
    
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        
        # Tag each query token as a positive or negative object
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        
        # Group consecutive tokens with the same label into one phrase per object
        pos = []
        neg = []
        last_tag = ''
        for token, label in predictions:
            if label == '<positive_object>':
                if last_tag != '<positive_object>':
                    pos.append(f"a photo of a {token}.")
                else:
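                    # Extend the current phrase: drop the trailing '.' and append the token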
                    pos[-1] = pos[-1][:-1] + f" {token}."
            elif label == '<negative_object>':
                if last_tag != '<negative_object>':
                    neg.append(f"a photo of a {token}.")
                else:
                    neg[-1] = neg[-1][:-1] + f" {token}."
            last_tag = label
            
        # Compose the query: add positive-object text embeddings, subtract negative ones
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
            
        # Re-normalize so the composed query matches the unit-norm database embeddings
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
        
    # Calculate similarities
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    
    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
    
    return most_similar_image_path

# Initialize model and database
model = load_model()

test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)

# Load the database once and reuse it for every query
database_df = test_dataset.load_database()
database_embeddings = encode_database(model, database_df)


def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model,
        selected_image,
        query_text,
        database_embeddings,
        database_df
    )
    return Image.open(result_image_path)

# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image"),
        gr.Textbox(label="Enter Query Text")
    ],
    outputs=gr.Image(label="Retrieved Image"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["example_images/image1.jpg", "a red car"],
        ["example_images/image2.jpg", "a blue house"]
    ]
)

if __name__ == "__main__":
    demo.launch()