traopia committed · Commit c57012a · Parent: 69fe46f

search fashiondb

Files changed:
- app_fashionDB.py +277 -0
- search_fashionDB.py +125 -0
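Neither file declares its dependencies in this commit. A hedged sketch of the environment the imports below appear to assume — every package name here is inferred from the import statements, not taken from a requirements file:

# Hypothetical environment check, inferred from the imports in this commit.
# Some module names differ from their PyPI package names
# (sklearn -> scikit-learn, PIL -> Pillow); pandas.read_parquet also needs
# a Parquet engine such as pyarrow.
import importlib.util

for module in ["gradio", "pandas", "numpy", "requests", "torch",
               "transformers", "sklearn", "PIL", "pyarrow"]:
    assert importlib.util.find_spec(module), f"missing module: {module}"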
app_fashionDB.py
ADDED
@@ -0,0 +1,277 @@
import gradio as gr
import pandas as pd
import numpy as np
from search_fashionDB import search_images_by_text, get_similar_images, search_images_by_image
import requests
from io import BytesIO


#@st.cache_data(show_spinner="Loading FashionDB...")
def load_data_hf():
    # Load the Parquet file directly from Hugging Face
    df_url = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/data_vogue_final.parquet"
    df = pd.read_parquet(df_url)
    df = df.explode("image_urls_sample")
    df = df.rename(columns={"image_urls_sample": "url", "URL": "collection"})

    df_fh = pd.read_parquet("https://huggingface.co/datasets/traopia/FashionDB/resolve/main/final_info_fh.parquet")
    df_designers = pd.read_parquet("https://huggingface.co/datasets/traopia/FashionDB/resolve/main/final_info_designers.parquet")

    # Load the .npy files via requests
    npy_url = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/fashion_clip.npy"
    response = requests.get(npy_url)
    response.raise_for_status()  # Raise an error if the download fails
    embeddings = np.load(BytesIO(response.content))

    image_urls = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/image_urls.npy"
    response = requests.get(image_urls)
    response.raise_for_status()  # Raise an error if the download fails
    embeddings_urls = np.load(BytesIO(response.content), allow_pickle=True)

    return df, df_fh, df_designers, embeddings, embeddings_urls


df, df_fh, df_designers, embeddings, embeddings_urls = load_data_hf()
# embeddings is a numpy array (N, D); embeddings_urls holds the matching urls/keys
embedding_map = {url: i for i, url in enumerate(embeddings_urls)}


# Filter and search
def filter_and_search(fashion_house, designer, category, season, start_year, end_year, query):
    filtered = df.copy()

    if fashion_house:
        filtered = filtered[filtered['fashion_house'].isin(fashion_house)]
    if designer:
        filtered = filtered[filtered['designer_name'].isin(designer)]
    if category:
        filtered = filtered[filtered['category'].isin(category)]
    if season:
        filtered = filtered[filtered['season'].isin(season)]
    filtered = filtered[(filtered['year'] >= start_year) & (filtered['year'] <= end_year)]

    if query:
        image_urls, metadata = search_images_by_text(query, filtered, embeddings, embeddings_urls)
    else:
        results = filtered.head(30)
        image_urls = results["url"].tolist()
        metadata = results.to_dict(orient="records")
    return image_urls, metadata


# Display metadata for the selected item
def show_metadata(idx, metadata):
    item = metadata[idx]
    out = ""
    for field in ["fashion_house", "designer_name", "season", "year", "category"]:
        if field in item and pd.notna(item[field]):
            out += f"**{field.replace('_', ' ').title()}**: {item[field]}\n\n"  # blank line so Markdown renders a line break
    if 'collection' in item and pd.notna(item['collection']):
        out += f"\n[View Collection]({item['collection']})"
    return out


def find_similar(idx, metadata, top_k=5):
    if not isinstance(idx, int) or idx >= len(metadata) or idx < 0:
        return [], []

    key = metadata[idx]["url"]  # each metadata record carries its image "url"

    image_urls, metadata = get_similar_images(df, key, embeddings, embedding_map, embeddings_urls, top_k=top_k)
    return image_urls, metadata


with gr.Blocks() as demo:
    gr.Markdown("# 👗 FashionDB Explorer")

    with gr.Tabs():
        # TEXT SEARCH TAB
        with gr.Tab("Search by Text"):
            with gr.Row():
                fashion_house = gr.Dropdown(label="Fashion House", choices=sorted(df["fashion_house"].dropna().unique()), multiselect=True)
                designer = gr.Dropdown(label="Fashion Designer", choices=sorted(df["designer_name"].dropna().unique()), multiselect=True)
                category = gr.Dropdown(label="Category", choices=sorted(df["category"].dropna().unique()), multiselect=True)
                season = gr.Dropdown(label="Season", choices=sorted(df["season"].dropna().unique()), multiselect=True)
                min_year = int(df['year'].min())
                max_year = int(df['year'].max())
                start_year = gr.Slider(label="Start Year", minimum=min_year, maximum=max_year, value=2000, step=1)
                end_year = gr.Slider(label="End Year", minimum=min_year, maximum=max_year, value=2024, step=1)

            query = gr.Textbox(label="Search by text", placeholder="e.g., pink dress")
            search_button = gr.Button("Search")

            result_gallery = gr.Gallery(label="Search Results", columns=5, height="auto")
            metadata_output = gr.Markdown()
            reference_image = gr.Image(label="Reference Image", interactive=False)
            similar_gallery = gr.Gallery(label="Similar Images", columns=5, height="auto")

            metadata_state = gr.State([])
            selected_idx = gr.Number(value=0, visible=False)

            def handle_search(fh, dis, cat, sea, sy, ey, q):
                imgs, meta = filter_and_search(fh, dis, cat, sea, sy, ey, q)
                return imgs, meta, "", [], None

            search_button.click(
                handle_search,
                inputs=[fashion_house, designer, category, season, start_year, end_year, query],
                outputs=[result_gallery, metadata_state, metadata_output, similar_gallery, reference_image]
            )

            def handle_click(evt: gr.SelectData, metadata):
                idx = evt.index
                md = show_metadata(idx, metadata)
                img_path = metadata[idx]["url"]
                return idx, md, img_path

            result_gallery.select(
                handle_click,
                inputs=[metadata_state],
                outputs=[selected_idx, metadata_output, reference_image]
            )

            def show_similar(idx, metadata):
                # gr.Number delivers a float, so coerce rather than string-test
                if idx is None:
                    return [], []
                return find_similar(int(idx), metadata)

            similar_metadata_state = gr.State()
            similar_metadata_output = gr.Markdown()

            show_similar_button = gr.Button("Show Similar Images")
            show_similar_button.click(
                show_similar,
                inputs=[selected_idx, metadata_state],
                outputs=[similar_gallery, similar_metadata_state]
            )

            def handle_similar_click(evt: gr.SelectData, metadata):
                idx = evt.index
                md = show_metadata(idx, metadata)
                img_path = metadata[idx]["url"]
                return idx, md, img_path

            similar_gallery.select(
                handle_similar_click,
                inputs=[similar_metadata_state],
                outputs=[selected_idx, similar_metadata_output, reference_image]
            )

        # IMAGE SEARCH TAB
        with gr.Tab("Search by Image"):
            with gr.Row():
                fashion_house_img = gr.Dropdown(label="Fashion House", choices=sorted(df["fashion_house"].dropna().unique()), multiselect=True)
                designer_img = gr.Dropdown(label="Fashion Designer", choices=sorted(df["designer_name"].dropna().unique()), multiselect=True)
                category_img = gr.Dropdown(label="Category", choices=sorted(df["category"].dropna().unique()), multiselect=True)
                season_img = gr.Dropdown(label="Season", choices=sorted(df["season"].dropna().unique()), multiselect=True)
                start_year_img = gr.Slider(label="Start Year", minimum=min_year, maximum=max_year, value=2000, step=1)
                end_year_img = gr.Slider(label="End Year", minimum=min_year, maximum=max_year, value=2024, step=1)

            uploaded_image = gr.Image(label="Upload an image", type="pil")
            search_by_image_button = gr.Button("Search by Image")

            uploaded_result_gallery = gr.Gallery(label="Search Results by Image", columns=5, height="auto")
            uploaded_metadata_state = gr.State([])
            uploaded_metadata_output = gr.Markdown()
            uploaded_reference_image = gr.Image(label="Reference Image", interactive=False)

            def handle_search_by_image(image, fh, dis, cat, sea, sy, ey):
                if image is None:
                    return [], [], "Please upload an image first."
                # Apply filters
                filtered_df = df.copy()
                if fh: filtered_df = filtered_df[filtered_df["fashion_house"].isin(fh)]
                if dis: filtered_df = filtered_df[filtered_df["designer_name"].isin(dis)]
                if cat: filtered_df = filtered_df[filtered_df["category"].isin(cat)]
                if sea: filtered_df = filtered_df[filtered_df["season"].isin(sea)]
                filtered_df = filtered_df[(filtered_df["year"] >= sy) & (filtered_df["year"] <= ey)]

                images, metadata = search_images_by_image(image, filtered_df, embeddings, embeddings_urls)
                return images, metadata, ""

            search_by_image_button.click(
                handle_search_by_image,
                inputs=[uploaded_image, fashion_house_img, designer_img, category_img, season_img, start_year_img, end_year_img],
                outputs=[uploaded_result_gallery, uploaded_metadata_state, uploaded_metadata_output]
            )

            uploaded_selected_idx = gr.Number(visible=False)

            def handle_uploaded_click(evt: gr.SelectData, metadata):
                idx = evt.index
                md = show_metadata(idx, metadata)
                img_path = metadata[idx]["url"]
                return idx, md, img_path

            uploaded_result_gallery.select(
                handle_uploaded_click,
                inputs=[uploaded_metadata_state],
                outputs=[uploaded_selected_idx, uploaded_metadata_output, uploaded_reference_image]
            )

            # SIMILAR IMAGE SEARCH FOR IMAGE TAB
            uploaded_similar_gallery = gr.Gallery(label="Similar Images", columns=5, height="auto")
            uploaded_similar_metadata_state = gr.State([])
            uploaded_similar_metadata_output = gr.Markdown()

            uploaded_show_similar_button = gr.Button("Show Similar Images")

            def show_similar_uploaded(idx, metadata):
                # gr.Number delivers a float, so coerce rather than string-test
                if idx is None:
                    return [], []
                return find_similar(int(idx), metadata)

            uploaded_show_similar_button.click(
                show_similar_uploaded,
                inputs=[uploaded_selected_idx, uploaded_metadata_state],
                outputs=[uploaded_similar_gallery, uploaded_similar_metadata_state]
            )

            def handle_uploaded_similar_click(evt: gr.SelectData, metadata):
                idx = evt.index
                md = show_metadata(idx, metadata)
                img_path = metadata[idx]["url"]
                return idx, md, img_path

            uploaded_similar_gallery.select(
                handle_uploaded_similar_click,
                inputs=[uploaded_similar_metadata_state],
                outputs=[uploaded_selected_idx, uploaded_similar_metadata_output, uploaded_reference_image]
            )

            uploaded_back_button = gr.Button("Back to Initial Uploaded Search")

            def back_to_uploaded_home():
                return [], "", None

            uploaded_back_button.click(
                back_to_uploaded_home,
                outputs=[uploaded_similar_gallery, uploaded_similar_metadata_output, uploaded_reference_image]
            )

        with gr.Tab("Query on FashionDB"):
            with gr.Row():
                gr.Markdown(
                    "### 🔗 Query FashionDB SPARQL Endpoint\n"
                    "[Click here to open the SPARQL endpoint](https://fashionwiki.wikibase.cloud/query/)",
                    elem_id="sparql-link"
                )

            back_button = gr.Button("Back to Home")

            def back_to_home():
                return [], "", None  # clear similar_gallery, metadata_output, reference image

            back_button.click(
                back_to_home,
                outputs=[similar_gallery, similar_metadata_output, reference_image]
            )

demo.launch()
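Before moving to the second file: the search path above can be smoke-tested without the UI. A minimal sketch, assuming demo.launch() is first moved under an if __name__ == "__main__": guard (otherwise importing the module starts the server); the filter value and query are made up:

# Hypothetical smoke test for the text-search path. Assumes demo.launch()
# in app_fashionDB.py sits under an `if __name__ == "__main__":` guard so
# that importing the module only loads data and defines functions.
from app_fashionDB import filter_and_search

urls, meta = filter_and_search(
    fashion_house=["Chanel"],   # made-up filter value
    designer=None, category=None, season=None,
    start_year=2000, end_year=2024,
    query="pink dress",         # made-up query
)
print(len(urls), "hits")
if meta:
    print(meta[0].get("fashion_house"), meta[0].get("year"))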
search_fashionDB.py
ADDED
@@ -0,0 +1,125 @@
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel

# Set device: use CUDA if available, otherwise MPS if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Load Fashion-CLIP model and processor
model_name = "patrickjohncyh/fashion-clip"
#model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# Initialize segmentation pipeline
segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes", device=device)


def segment_clothing_white(img, clothes=["Background"]):
    segments = segmenter(img)

    # Create list of masks for the requested labels
    mask_list = []
    for s in segments:
        if s['label'] in clothes:
            mask_list.append(s['mask'])

    if not mask_list:
        print("No clothing segments found in image.")
        return img  # Return the original image if no segments are found

    # Combine all masks into a single mask
    final_mask = np.array(mask_list[0])
    for mask in mask_list[1:]:
        final_mask = np.maximum(final_mask, np.array(mask))  # Combine masks using max

    # Apply the mask to the image
    img_array = np.array(img)             # Convert image to numpy array
    final_mask = final_mask.astype(bool)  # Convert mask to boolean
    img_array[final_mask] = [255, 255, 255]  # Paint the masked regions (by default the background) white

    # Convert back to PIL image
    segmented_img = Image.fromarray(img_array)
    return segmented_img


def encode_image(image):
    """Encode an image into an L2-normalised embedding."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
    embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
    return embedding.cpu().numpy().astype(np.float32).flatten()


def metadata_for_urls(df, urls):
    """Look up metadata for a ranked list of image urls, preserving the rank
    order (so that gallery index i matches metadata[i]) and dropping urls
    that fall outside the filtered df or appear more than once."""
    by_url = df.drop_duplicates("url").set_index("url")
    kept, metadata, seen = [], [], set()
    for url in urls:
        if url in seen or url not in by_url.index:
            continue
        seen.add(url)
        kept.append(url)
        metadata.append({**by_url.loc[url].to_dict(), "url": url})
    return kept, metadata


def search_images_by_image(uploaded_image, df, embeddings, embeddings_urls, top_k=30):
    # Convert to PIL
    if isinstance(uploaded_image, str):
        uploaded_image = Image.open(uploaded_image).convert("RGB")
    elif isinstance(uploaded_image, np.ndarray):
        uploaded_image = Image.fromarray(uploaded_image).convert("RGB")

    # Encode with CLIP
    image_emb = encode_image(uploaded_image)

    # Similarity against ALL embeddings; the filters encoded in df only
    # restrict which hits come back with metadata
    sims = cosine_similarity([image_emb], embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]
    top_urls = [embeddings_urls[i] for i in top_indices]
    return metadata_for_urls(df, top_urls)


def search_images_by_text(text, df, embeddings, embeddings_urls, top_k=30):
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
    with torch.no_grad():
        text_emb = model.get_text_features(**inputs).cpu().numpy()

    df_indices = df.index.to_numpy()
    # Slice embeddings & urls to match the filtered df, so positions in the
    # similarity vector map back to the right urls
    embeddings_filtered = embeddings[df_indices]
    urls_filtered = np.asarray(embeddings_urls)[df_indices]
    sims = cosine_similarity(text_emb, embeddings_filtered)[0]
    sims = np.asarray(sims).flatten()
    top_indices = np.argsort(sims)[::-1][:top_k]
    top_urls = [urls_filtered[i] for i in top_indices]
    return metadata_for_urls(df, top_urls)


def get_similar_images(df, image_key, embeddings, embedding_map, embeddings_urls, top_k=5):
    if image_key not in embedding_map:
        return [], []  # fallback: no match found (callers expect a (urls, metadata) pair)

    index = embedding_map[image_key]
    query_emb = embeddings[index]

    sims = cosine_similarity([query_emb], embeddings)[0]
    top_indices = np.argsort(sims)[::-1][1:top_k + 1]  # skip the query image itself
    top_urls = [embeddings_urls[i] for i in top_indices]
    return metadata_for_urls(df, top_urls)
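Since the CLIP model and the segmenter load at import time, the retrieval helpers are also usable outside Gradio. A minimal sketch under the same guard assumption as above; example.jpg is a made-up local path:

# Hypothetical standalone use of the retrieval helpers. Assumes demo.launch()
# in app_fashionDB.py is guarded, so the import only loads data;
# "example.jpg" is a made-up local file.
from PIL import Image
from app_fashionDB import df, embeddings, embeddings_urls
from search_fashionDB import search_images_by_image, search_images_by_text

query_img = Image.open("example.jpg").convert("RGB")
urls, meta = search_images_by_image(query_img, df, embeddings, embeddings_urls, top_k=5)

urls_txt, meta_txt = search_images_by_text("red evening gown", df, embeddings, embeddings_urls, top_k=5)
for url, m in zip(urls_txt, meta_txt):
    print(url, m.get("fashion_house"), m.get("season"))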