safinal commited on
Commit
08adf1a
Β·
verified Β·
1 Parent(s): 3155f7d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+
9
+ # Import your model and necessary functions
10
+ from src.config import ConfigManager
11
+ from src.token_classifier import load_token_classifier, predict
12
+ from your_model_file import YourModel # Replace with your actual model import
13
+
14
+ # Load model and configurations
15
+ def load_model():
16
+ model = YourModel() # Initialize your model
17
+ model.eval()
18
+ return model
19
+
20
+ def load_dataset():
21
+ # Load your default dataset
22
+ database_df = pd.read_csv('database.csv') # Adjust path as needed
23
+ return database_df
24
+
25
+ def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
26
+ device = ConfigManager().get("training")["device"]
27
+
28
+ # Process query image
29
+ query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
30
+
31
+ # Get token classifier
32
+ token_classifier, token_classifier_tokenizer = load_token_classifier(
33
+ ConfigManager().get("paths")["pretrained_token_classifier_path"],
34
+ device
35
+ )
36
+
37
+ with torch.no_grad():
38
+ query_img_embd = model.feature_extractor.encode_image(query_img)
39
+
40
+ # Process text query
41
+ predictions = predict(
42
+ tokens=query_text,
43
+ model=token_classifier,
44
+ tokenizer=token_classifier_tokenizer,
45
+ device=device,
46
+ max_length=128
47
+ )
48
+
49
+ # Process positive and negative objects
50
+ pos = []
51
+ neg = []
52
+ last_tag = ''
53
+ for token, label in predictions:
54
+ if label == '<positive_object>':
55
+ if last_tag != '<positive_object>':
56
+ pos.append(f"a photo of a {token}.")
57
+ else:
58
+ pos[-1] = pos[-1][:-1] + f" {token}."
59
+ elif label == '<negative_object>':
60
+ if last_tag != '<negative_object>':
61
+ neg.append(f"a photo of a {token}.")
62
+ else:
63
+ neg[-1] = neg[-1][:-1] + f" {token}."
64
+ last_tag = label
65
+
66
+ # Combine embeddings
67
+ for obj in pos:
68
+ query_img_embd += model.feature_extractor.encode_text(
69
+ model.tokenizer(obj).to(device)
70
+ )[0]
71
+ for obj in neg:
72
+ query_img_embd -= model.feature_extractor.encode_text(
73
+ model.tokenizer(obj).to(device)
74
+ )[0]
75
+
76
+ query_img_embd = torch.nn.functional.normalize(query_img_embd, dim=1, p=2)
77
+
78
+ # Calculate similarities
79
+ query_embedding = query_img_embd.cpu().numpy()
80
+ similarities = cosine_similarity(query_embedding, database_embeddings)[0]
81
+
82
+ # Get most similar image
83
+ most_similar_idx = np.argmax(similarities)
84
+ most_similar_image_path = database_df.iloc[most_similar_idx]['target_image']
85
+
86
+ return most_similar_image_path
87
+
88
+ # Initialize model and database
89
+ model = load_model()
90
+ database_df = load_dataset()
91
+ database_embeddings = encode_database(model, database_df) # Using your existing function
92
+
93
+ def interface_fn(selected_image, query_text):
94
+ result_image_path = process_single_query(
95
+ model,
96
+ selected_image,
97
+ query_text,
98
+ database_embeddings,
99
+ database_df
100
+ )
101
+ return Image.open(result_image_path)
102
+
103
+ # Create Gradio interface
104
+ demo = gr.Interface(
105
+ fn=interface_fn,
106
+ inputs=[
107
+ gr.Image(type="filepath", label="Select Query Image"),
108
+ gr.Textbox(label="Enter Query Text")
109
+ ],
110
+ outputs=gr.Image(label="Retrieved Image"),
111
+ title="Compositional Image Retrieval",
112
+ description="Select an image and enter a text query to find the most similar image.",
113
+ examples=[
114
+ ["example_images/image1.jpg", "a red car"],
115
+ ["example_images/image2.jpg", "a blue house"]
116
+ ]
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ demo.launch()