traopia commited on
Commit
c57012a
·
1 Parent(s): 69fe46f

search fashiondb

Browse files
Files changed (2) hide show
  1. app_fashionDB.py +277 -0
  2. search_fashionDB.py +125 -0
app_fashionDB.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from search_fashionDB import search_images_by_text, get_similar_images, search_images_by_image
5
+ import requests
6
+ from io import BytesIO
7
+
8
+ import requests
9
+ from io import BytesIO
10
+
11
+
12
+ #@st.cache_data(show_spinner="Loading FashionDB...")
13
+ def load_data_hf():
14
+ # Load the Parquet file directly from Hugging Face
15
+ df_url = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/data_vogue_final.parquet"
16
+ df = pd.read_parquet(df_url)
17
+ df = df.explode("image_urls_sample")
18
+ df = df.rename(columns={"image_urls_sample":"url", "URL":"collection"})
19
+
20
+ df_fh = pd.read_parquet("https://huggingface.co/datasets/traopia/FashionDB/resolve/main/final_info_fh.parquet")
21
+ df_designers = pd.read_parquet("https://huggingface.co/datasets/traopia/FashionDB/resolve/main/final_info_designers.parquet")
22
+
23
+ # Load the .npy file using requests
24
+ npy_url = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/fashion_clip.npy"
25
+ response = requests.get(npy_url)
26
+ response.raise_for_status() # Raise error if download fails
27
+ embeddings = np.load(BytesIO(response.content))
28
+ image_urls = "https://huggingface.co/datasets/traopia/FashionDB/resolve/main/image_urls.npy"
29
+ response = requests.get(image_urls)
30
+ response.raise_for_status() # Raise error if download fails
31
+ embeddings_urls = np.load(BytesIO(response.content), allow_pickle=True)
32
+
33
+ return df, df_fh, df_designers, embeddings, embeddings_urls
34
+
35
+
36
+
37
+
38
+ df, df_fh, df_designers, embeddings, embeddings_urls = load_data_hf()
39
+ # Suppose embeddings is a numpy array (N, D) and embeddings_urls is a list of urls/keys
40
+ embedding_map = {url: i for i, url in enumerate(embeddings_urls)}
41
+
42
+ # Filter and search
43
+ def filter_and_search(fashion_house, designer, category, season, start_year, end_year, query):
44
+ filtered = df.copy()
45
+
46
+ if fashion_house:
47
+ filtered = filtered[filtered['fashion_house'].isin(fashion_house)]
48
+
49
+ if designer:
50
+ filtered = filtered[filtered['designer_name'].isin(designer)]
51
+ if category:
52
+ filtered = filtered[filtered['category'].isin(category)]
53
+ if season:
54
+ filtered = filtered[filtered['season'].isin(season)]
55
+ filtered = filtered[(filtered['year'] >= start_year) & (filtered['year'] <= end_year)]
56
+
57
+ if query:
58
+ image_urls, metadata = search_images_by_text(query, filtered, embeddings, embeddings_urls)
59
+ else:
60
+ results = filtered.head(30)
61
+ image_urls = results["url"].tolist()
62
+ metadata = results.to_dict(orient="records")
63
+ return image_urls, metadata
64
+
65
+ # Display metadata and similar
66
+ def show_metadata(idx, metadata):
67
+ item = metadata[idx]
68
+ out = ""
69
+ for field in ["fashion_house", "designer_name", "season", "year", "category"]:
70
+ if field in item and pd.notna(item[field]):
71
+ out += f"**{field.title()}**: {item[field]}\n"
72
+ if 'collection' in item and pd.notna(item['collection']):
73
+ out += f"\n[View Collection]({item['collection']})"
74
+ return out
75
+
76
+
77
+
78
+ def find_similar(idx, metadata,top_k=5):
79
+ if not isinstance(idx, int) or idx >= len(metadata) or idx < 0:
80
+ return [], []
81
+
82
+ key = metadata[idx]["url"] # assumes each row has "key" (url or id)
83
+
84
+ image_urls, metadata = get_similar_images(df, key, embeddings, embedding_map, embeddings_urls, top_k=top_k)
85
+ return image_urls,metadata
86
+
87
+
88
+
89
+
90
+
91
+ with gr.Blocks() as demo:
92
+ gr.Markdown("# 👗 FashionDB Explorer")
93
+
94
+ with gr.Tabs():
95
+ # TEXT SEARCH TAB
96
+ with gr.Tab("Search by Text"):
97
+ with gr.Row():
98
+ fashion_house = gr.Dropdown(label="Fashion House", choices=sorted(df["fashion_house"].dropna().unique()), multiselect=True)
99
+ designer = gr.Dropdown(label="Fashion Designer", choices=sorted(df["designer_name"].dropna().unique()), multiselect=True)
100
+ category = gr.Dropdown(label="Category", choices=sorted(df["category"].dropna().unique()), multiselect=True)
101
+ season = gr.Dropdown(label="Season", choices=sorted(df["season"].dropna().unique()), multiselect=True)
102
+ min_year = int(df['year'].min())
103
+ max_year = int(df['year'].max())
104
+ start_year = gr.Slider(label="Start Year", minimum=min_year, maximum=max_year, value=2000, step=1)
105
+ end_year = gr.Slider(label="End Year", minimum=min_year, maximum=max_year, value=2024, step=1)
106
+
107
+ query = gr.Textbox(label="Search by text", placeholder="e.g., pink dress")
108
+ search_button = gr.Button("Search")
109
+
110
+ result_gallery = gr.Gallery(label="Search Results", columns=5, height="auto")
111
+ metadata_output = gr.Markdown()
112
+ reference_image = gr.Image(label="Reference Image", interactive=False)
113
+ similar_gallery = gr.Gallery(label="Similar Images", columns=5, height="auto")
114
+
115
+ metadata_state = gr.State([])
116
+ selected_idx = gr.Number(value=0, visible=False)
117
+
118
+ def handle_search(fh, dis, cat, sea, sy, ey, q):
119
+ imgs, meta = filter_and_search(fh, dis, cat, sea, sy, ey, q)
120
+ return imgs, meta, "", [], None
121
+
122
+ search_button.click(
123
+ handle_search,
124
+ inputs=[fashion_house, designer, category, season, start_year, end_year, query],
125
+ outputs=[result_gallery, metadata_state, metadata_output, similar_gallery, reference_image]
126
+ )
127
+
128
+ def handle_click(evt: gr.SelectData, metadata):
129
+ idx = evt.index
130
+ md = show_metadata(idx, metadata)
131
+ img_path = metadata[idx]["url"]
132
+ return idx, md, img_path
133
+
134
+ result_gallery.select(
135
+ handle_click,
136
+ inputs=[metadata_state],
137
+ outputs=[selected_idx, metadata_output, reference_image]
138
+ )
139
+
140
+ def show_similar(idx, metadata):
141
+ if idx is None or not str(idx).isdigit():
142
+ return [], []
143
+ return find_similar(int(idx), metadata)
144
+
145
+ similar_metadata_state = gr.State()
146
+ similar_metadata_output = gr.Markdown()
147
+
148
+ show_similar_button = gr.Button("Show Similar Images")
149
+ show_similar_button.click(
150
+ show_similar,
151
+ inputs=[selected_idx, metadata_state],
152
+ outputs=[similar_gallery, similar_metadata_state]
153
+ )
154
+
155
+ def handle_similar_click(evt: gr.SelectData, metadata):
156
+ idx = evt.index
157
+ md = show_metadata(idx, metadata)
158
+ img_path = metadata[idx]["url"]
159
+ return idx, md, img_path
160
+
161
+ similar_gallery.select(
162
+ handle_similar_click,
163
+ inputs=[similar_metadata_state],
164
+ outputs=[selected_idx, similar_metadata_output, reference_image]
165
+ )
166
+
167
+ # IMAGE SEARCH TAB
168
+ with gr.Tab("Search by Image"):
169
+ with gr.Row():
170
+ fashion_house_img = gr.Dropdown(label="Fashion House", choices=sorted(df["fashion_house"].dropna().unique()), multiselect=True)
171
+ designer_img = gr.Dropdown(label="Fashion Designer", choices=sorted(df["designer_name"].dropna().unique()), multiselect=True)
172
+ category_img = gr.Dropdown(label="Category", choices=sorted(df["category"].dropna().unique()), multiselect=True)
173
+ season_img = gr.Dropdown(label="Season", choices=sorted(df["season"].dropna().unique()), multiselect=True)
174
+ start_year_img = gr.Slider(label="Start Year", minimum=min_year, maximum=max_year, value=2000, step=1)
175
+ end_year_img = gr.Slider(label="End Year", minimum=min_year, maximum=max_year, value=2024, step=1)
176
+
177
+ uploaded_image = gr.Image(label="Upload an image", type="pil")
178
+ search_by_image_button = gr.Button("Search by Image")
179
+
180
+ uploaded_result_gallery = gr.Gallery(label="Search Results by Image", columns=5, height="auto")
181
+ uploaded_metadata_state = gr.State([])
182
+ uploaded_metadata_output = gr.Markdown()
183
+ uploaded_reference_image = gr.Image(label="Reference Image", interactive=False)
184
+
185
+ def handle_search_by_image(image, fh, dis, cat, sea, sy, ey):
186
+ if image is None:
187
+ return [], "Please upload an image first.", None
188
+ # Apply filters
189
+ filtered_df = df.copy()
190
+ if fh: filtered_df = filtered_df[filtered_df["fashion_house"].isin(fh)]
191
+ if dis: filtered_df = filtered_df[filtered_df["designer_name"].isin(fh)]
192
+ if cat: filtered_df = filtered_df[filtered_df["category"].isin(cat)]
193
+ if sea: filtered_df = filtered_df[filtered_df["season"].isin(sea)]
194
+ filtered_df = filtered_df[(filtered_df["year"] >= sy) & (filtered_df["year"] <= ey)]
195
+
196
+ images, metadata = search_images_by_image(image, filtered_df, embeddings, embeddings_urls)
197
+ return images, metadata, ""
198
+
199
+ search_by_image_button.click(
200
+ handle_search_by_image,
201
+ inputs=[uploaded_image, fashion_house_img, designer_img, category_img, season_img, start_year_img, end_year_img],
202
+ outputs=[uploaded_result_gallery, uploaded_metadata_state, uploaded_metadata_output]
203
+ )
204
+
205
+ uploaded_selected_idx = gr.Number(visible=False)
206
+
207
+ def handle_uploaded_click(evt: gr.SelectData, metadata):
208
+ idx = evt.index
209
+ md = show_metadata(idx, metadata)
210
+ img_path = metadata[idx]["url"]
211
+ return idx, md, img_path
212
+
213
+ uploaded_result_gallery.select(
214
+ handle_uploaded_click,
215
+ inputs=[uploaded_metadata_state],
216
+ outputs=[uploaded_selected_idx, uploaded_metadata_output, uploaded_reference_image]
217
+ )
218
+
219
+ # SIMILAR IMAGE SEARCH FOR IMAGE TAB
220
+ uploaded_similar_gallery = gr.Gallery(label="Similar Images", columns=5, height="auto")
221
+ uploaded_similar_metadata_state = gr.State([])
222
+ uploaded_similar_metadata_output = gr.Markdown()
223
+
224
+ uploaded_show_similar_button = gr.Button("Show Similar Images")
225
+
226
+ def show_similar_uploaded(idx, metadata):
227
+ if idx is None or not str(idx).isdigit():
228
+ return [], []
229
+ return find_similar(int(idx), metadata)
230
+
231
+ uploaded_show_similar_button.click(
232
+ show_similar_uploaded,
233
+ inputs=[uploaded_selected_idx, uploaded_metadata_state],
234
+ outputs=[uploaded_similar_gallery, uploaded_similar_metadata_state]
235
+ )
236
+
237
+ def handle_uploaded_similar_click(evt: gr.SelectData, metadata):
238
+ idx = evt.index
239
+ md = show_metadata(idx, metadata)
240
+ img_path = metadata[idx]["url"]
241
+ return idx, md, img_path
242
+
243
+ uploaded_similar_gallery.select(
244
+ handle_uploaded_similar_click,
245
+ inputs=[uploaded_similar_metadata_state],
246
+ outputs=[uploaded_selected_idx, uploaded_similar_metadata_output, uploaded_reference_image]
247
+ )
248
+
249
+ uploaded_back_button = gr.Button("Back to Initial Uploaded Search")
250
+
251
+ def back_to_uploaded_home():
252
+ return [], "", None
253
+
254
+ uploaded_back_button.click(
255
+ back_to_uploaded_home,
256
+ outputs=[uploaded_similar_gallery, uploaded_similar_metadata_output, uploaded_reference_image]
257
+ )
258
+
259
+ with gr.Tab("Query on FashionDB"):
260
+ with gr.Row():
261
+ gr.Markdown(
262
+ "### 🔗 Query FashionDB SPARQL Endpoint\n"
263
+ "[Click here to open the SPARQL endpoint](https://fashionwiki.wikibase.cloud/query/)",
264
+ elem_id="sparql-link"
265
+ )
266
+
267
+ back_button = gr.Button("Back to Home")
268
+
269
+ def back_to_home():
270
+ return [], "", None # clear similar_gallery, metadata_output, reference image
271
+
272
+ back_button.click(
273
+ back_to_home,
274
+ outputs=[similar_gallery, similar_metadata_output, reference_image]
275
+ )
276
+
277
+ demo.launch()
search_fashionDB.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics.pairwise import cosine_similarity
2
+ import numpy as np
3
+
4
+
5
+ from transformers import pipeline
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ import torch
10
+ from transformers import CLIPProcessor, CLIPModel
11
+ import pandas as pd
12
+
13
+ #set device: Use GPU if availanle, otherwise mps if available otherwise CPU
14
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
15
+
16
+
17
+ # Load Fashion-CLIP model and processor
18
+ model_name = "patrickjohncyh/fashion-clip"
19
+ #model_name = "openai/clip-vit-base-patch32"
20
+ model = CLIPModel.from_pretrained(model_name).to(device)
21
+ processor = CLIPProcessor.from_pretrained(model_name)
22
+
23
+ # Initialize segmentation pipeline
24
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes", device = device)
25
+
26
+
27
+ def segment_clothing_white(img, clothes=["Background"]):
28
+ segments = segmenter(img)
29
+
30
+ # Create list of masks
31
+ mask_list = []
32
+ for s in segments:
33
+ if s['label'] in clothes:
34
+ mask_list.append(s['mask'])
35
+
36
+ if not mask_list:
37
+ print("No clothing segments found in image.")
38
+ return img # Return the original image if no segments are found
39
+
40
+ # Combine all masks into a single mask
41
+ final_mask = np.array(mask_list[0])
42
+ for mask in mask_list[1:]:
43
+ final_mask = np.maximum(final_mask, np.array(mask)) # Combine masks using max
44
+
45
+ # Apply the mask to the image
46
+ img_array = np.array(img) # Convert image to numpy array
47
+ final_mask = final_mask.astype(bool) # Convert mask to boolean
48
+ img_array[final_mask] = [255,255,255] # Set unmasked regions to black
49
+
50
+ # Convert back to PIL image
51
+ segmented_img = Image.fromarray(img_array)
52
+ return segmented_img
53
+
54
+ def encode_image(image):
55
+ """Encode image into an embedding."""
56
+ inputs = processor(images=image, return_tensors="pt").to(device)
57
+
58
+ with torch.no_grad():
59
+ embedding = model.get_image_features(**inputs).cpu().numpy() # Move to CPU for stability
60
+ embedding = embedding / torch.linalg.norm(torch.tensor(embedding), ord=2, dim=-1, keepdim=True)
61
+ embedding = embedding.numpy().astype(np.float32).flatten()
62
+ return embedding
63
+
64
+ from PIL import Image
65
+ import torchvision.transforms as T
66
+
67
+
68
+
69
+ def search_images_by_image(uploaded_image, df, embeddings,embeddings_urls, top_k=30):
70
+ # Convert to PIL
71
+ if isinstance(uploaded_image, str):
72
+ uploaded_image = Image.open(uploaded_image).convert("RGB")
73
+ elif isinstance(uploaded_image, np.ndarray):
74
+ uploaded_image = Image.fromarray(uploaded_image).convert("RGB")
75
+
76
+ # Encode with CLIP
77
+ image_emb = encode_image(uploaded_image)
78
+
79
+ # Similarity against ALL embeddings
80
+ sims = cosine_similarity([image_emb], embeddings)[0]
81
+ top_indices = np.argsort(sims)[::-1][:top_k]
82
+ top_urls = [embeddings_urls[i] for i in top_indices]
83
+ metadata = df[df["url"].isin(top_urls)].copy().to_dict(orient="records")
84
+
85
+
86
+ return top_urls, metadata
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+ def search_images_by_text(text, df, embeddings, embeddings_urls, top_k=30):
96
+ inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
97
+ with torch.no_grad():
98
+ text_emb = model.get_text_features(**inputs).cpu().numpy()
99
+
100
+ df_indices = df.index.to_numpy()
101
+ # slice embeddings & urls to match the filtered df
102
+ embeddings_filtered = embeddings[df_indices]
103
+ sims = cosine_similarity(text_emb, embeddings_filtered)[0]
104
+ sims = np.asarray(sims).flatten()
105
+ top_indices = np.argsort(sims)[::-1][:top_k]
106
+ top_urls = [embeddings_urls[i] for i in top_indices]
107
+ metadata = df[df["url"].isin(top_urls)].copy().to_dict(orient="records")
108
+
109
+ return top_urls, metadata
110
+
111
+ def get_similar_images(df, image_key, embeddings, embedding_map, embeddings_urls, top_k=5):
112
+ if image_key not in embedding_map:
113
+ return pd.DataFrame() # fallback: no match found
114
+
115
+ index = embedding_map[image_key]
116
+ query_emb = embeddings[index]
117
+
118
+ sims = cosine_similarity([query_emb], embeddings)[0]
119
+ top_indices = np.argsort(sims)[::-1][1:top_k+1] # skip itself
120
+ top_urls = [embeddings_urls[i] for i in top_indices]
121
+ metadata = df[df["url"].isin(top_urls)].copy().to_dict(orient="records")
122
+
123
+ return top_urls, metadata
124
+
125
+