import os
import zipfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm


from token_classifier import load_token_classifier, predict
from model import Model
from dataset import RetrievalDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 512


def unzip_file(zip_path, extract_path):
    # Create the target directory if it doesn't exist
    os.makedirs(extract_path, exist_ok=True)
    
    # Open the zip file
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Extract all contents to the specified directory
        zip_ref.extractall(extract_path)

# Extract the bundled sample evaluation set
zip_path = "sample_evaluation.zip"
extract_path = "sample_evaluation"
unzip_file(zip_path, extract_path)

# Download the trained retrieval weights from the Hugging Face Hub
hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')



def encode_database(model, df: pd.DataFrame) -> np.ndarray:
    """
    Process database images and generate embeddings.

    Args:
        model: Retrieval model exposing a `processor` transform and a `feature_extractor`.
        df (pd.DataFrame): DataFrame with column:
            - target_image: str, paths to database images

    Returns:
        np.ndarray: Embeddings array of shape (num_images, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), batch_size)):
        batch_paths = df['target_image'][i:i + batch_size]
        target_imgs = torch.stack(
            [model.processor(Image.open(path)) for path in batch_paths]
        ).to(device)
        with torch.no_grad():
            # target_imgs_embedding = model.encode_database_image(target_imgs)
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
        target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)
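
# Minimal usage sketch (hypothetical paths, for illustration only): the gallery
# index is built once from a DataFrame with a `target_image` column.
#   df = pd.DataFrame({"target_image": ["images/0001.png", "images/0002.png"]})
#   db_embeddings = encode_database(model, df)  # -> array of shape (len(df), embedding_dim)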


# Load the retrieval model with the downloaded weights
def load_model():
    model = Model(model_name="ViTamin-L-384", pretrained=None)
    model.load("weights.pth")
    model.eval()
    return model


def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
    """Return the path of the database image that best matches the query image after applying the text edits."""

    # Process query image
    query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
    
    # Load the token classifier (re-loaded on every call; could be cached at module level)
    token_classifier, token_classifier_tokenizer = load_token_classifier(
        "safinal/compositional-image-retrieval-token-classifier",
        device
    )
    
    with torch.no_grad():
        query_img_embd = model.feature_extractor.encode_image(query_img)
        
        # Tag each word of the text query as a positive (add) or negative (remove) object
        predictions = predict(
            tokens=query_text,
            model=token_classifier,
            tokenizer=token_classifier_tokenizer,
            device=device,
            max_length=128
        )
        
        # Group consecutive tokens that share a label into object phrases
        pos = []
        neg = []
        last_tag = ''
        for token, label in predictions:
            if label == '<positive_object>':
                if last_tag != '<positive_object>':
                    pos.append(f"a photo of a {token}.")
                else:
                    pos[-1] = pos[-1][:-1] + f" {token}."
            elif label == '<negative_object>':
                if last_tag != '<negative_object>':
                    neg.append(f"a photo of a {token}.")
                else:
                    neg[-1] = neg[-1][:-1] + f" {token}."
            last_tag = label
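
        # At this point `pos` / `neg` hold grouped phrases; for a query such as
        # "add bowl and remove shoe" the classifier would typically yield
        # pos == ["a photo of a bowl."] and neg == ["a photo of a shoe."].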
            
        # Compose the query: add text embeddings of objects to include, subtract those to remove
        for obj in pos:
            query_img_embd += model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
        for obj in neg:
            query_img_embd -= model.feature_extractor.encode_text(
                model.tokenizer(obj).to(device)
            )[0]
            
        query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
        
    # Calculate similarities
    query_embedding = query_img_embd.cpu().numpy()
    similarities = cosine_similarity(query_embedding, database_embeddings)[0]
    
    # Get most similar image
    most_similar_idx = np.argmax(similarities)
    most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
    
    return most_similar_image_path

# Initialize model and database
model = load_model()

test_dataset = RetrievalDataset(
    img_dir_path="sample_evaluation/images",
    annotations_file_path="sample_evaluation/data.csv",
    split='test',
    transform=model.processor,
    tokenizer=model.tokenizer
)

database_embeddings = encode_database(model, test_dataset.load_database())  # pre-compute gallery embeddings once

def interface_fn(selected_image, query_text):
    result_image_path = process_single_query(
        model, 
        selected_image, 
        query_text, 
        database_embeddings, 
        test_dataset.load_database()
    )
    return Image.open(result_image_path)

# Create Gradio interface
demo = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(type="filepath", label="Select Query Image", image_mode="RGB"),
        gr.Textbox(label="Enter Query Text", lines=2)
    ],
    outputs=gr.Image(label="Retrieved Image", type="pil"),
    title="Compositional Image Retrieval",
    description="Select an image and enter a text query to find the most similar image.",
    examples=[
        ["sample_evaluation/images/261684.png", "Bring cow into the picture, and then follow up with removing bench."],
        ["sample_evaluation/images/283700.png", "add bowl and bench and remove shoe and elephant"],
        ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
        ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
    ],
    allow_flagging="never",
    cache_examples=False
)

if __name__ == "__main__":
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        print(f"Error launching app: {str(e)}")
        raise