Anusha806 committed on
Commit cd83986 · 1 Parent(s): 8e7d115
Files changed (1)
  1. app.py +345 -317
app.py CHANGED
@@ -1,334 +1,61 @@
- """Hybrid Multimodal Vector Search for E-Commerce Product Discovery"""
-
- import os
- import time
- import numpy as np
- from PIL import Image, ImageOps
- from datasets import load_dataset
- from pinecone import Pinecone, ServerlessSpec
- from pinecone_text.sparse import BM25Encoder
- from sentence_transformers import SentenceTransformer
- import torch
- import gradio as gr
- import pandas as pd
-
- # Set Pinecone API Key and config
- os.environ["PINECONE_API_KEY"] = "pcsk_************"  # hardcoded key redacted; set PINECONE_API_KEY in the environment instead
- api_key = os.environ.get('PINECONE_API_KEY')
- pc = Pinecone(api_key=api_key)
-
- cloud = os.environ.get('PINECONE_CLOUD', 'aws')
- region = os.environ.get('PINECONE_REGION', 'us-east-1')
- spec = ServerlessSpec(cloud=cloud, region=region)
- index_name = "hybrid-image-search"
-
- # Create and connect to index
- if index_name not in pc.list_indexes().names():
-     pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
-     while not pc.describe_index(index_name).status['ready']:
-         time.sleep(1)
-
- index = pc.Index(index_name)
- index.describe_index_stats()
-
- # Load dataset
- fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
- images = fashion["image"]
- metadata = fashion.remove_columns("image").to_pandas()
-
- # Fit BM25
- bm25 = BM25Encoder()
- bm25.fit(metadata['productDisplayName'])
-
- # Load CLIP model
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
-
- # Hybrid scaler
- def hybrid_scale(dense, sparse, alpha: float):
-     if alpha < 0 or alpha > 1:
-         raise ValueError("Alpha must be between 0 and 1")
-     hsparse = {
-         'indices': sparse['indices'],
-         'values': [v * (1 - alpha) for v in sparse['values']]
-     }
-     hdense = [v * alpha for v in dense]
-     return hdense, hsparse
-
- # Metadata filter extractor
- from PIL import Image, ImageOps
- import numpy as np
-
- def extract_metadata_filters(query: str):
-     query_lower = query.lower()
-     gender = None
-     category = None
-     subcategory = None
-     color = None
-
-     # --- Gender Mapping ---
-     gender_map = {
-         "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
-         "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
-         "boys": "Boys", "boy": "Boys",
-         "girls": "Girls", "girl": "Girls",
-         "kids": "Kids", "unisex": "Unisex"
-     }
-     for term, mapped_value in gender_map.items():
-         if term in query_lower:
-             gender = mapped_value
-             break
-
-     # --- Category Mapping ---
-     category_map = {
-         "shirt": "Shirts",
-         "tshirt": "Tshirts", "t-shirt": "Tshirts",
-         "jeans": "Jeans",
-         "watch": "Watches",
-         "kurta": "Kurtas",
-         "dress": "Dresses", "dresses": "Dresses",
-         "trousers": "Trousers", "pants": "Trousers",
-         "shorts": "Shorts",
-         "footwear": "Footwear",
-         "shoes": "Footwear",
-         "fashion": "Apparel"
-     }
-     for term, mapped_value in category_map.items():
-         if term in query_lower:
-             category = mapped_value
-             break
-
-     # --- SubCategory Mapping ---
-     subCategory_list = [
-         "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
-         "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
-         "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
-         "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
-         "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
-         "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
-         "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
-         "Water Bottle", "Wristbands"
-     ]
-     if "topwear" in query_lower or "top" in query_lower:
-         subcategory = "Topwear"
-     else:
-         for subcat in subCategory_list:
-             if subcat.lower() in query_lower:
-                 subcategory = subcat
-                 break
-
-     # --- Color Extraction ---
-     colors = [
-         "red","blue","green","yellow","black","white",
-         "orange","pink","purple","brown","grey","beige"
-     ]
-     for c in colors:
-         if c in query_lower:
-             color = c.capitalize()
-             break
-
-     # --- Invalid pairs ---
-     invalid_pairs = {
-         ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
-         ("Boys", "Dresses"), ("Boys", "Sarees"),
-         ("Girls", "Boxers"), ("Men", "Heels")
-     }
-     if (gender, category) in invalid_pairs:
-         print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
-         gender = None
-
-     # fallback
-     if gender and not category:
-         category = "Apparel"
-
-     return gender, category, subcategory, color
-
-
- def search_fashion(query: str, alpha: float):
-     gender, category, subcategory, color = extract_metadata_filters(query)
-
-     # Build Pinecone filter
-     filter = {}
-     if gender:
-         filter["gender"] = gender
-     if category:
-         filter["articleType"] = category
-     if subcategory:
-         filter["subCategory"] = subcategory
-     if color:
-         filter["baseColour"] = color
-
-     print(f"🔍 Using filter: {filter}")
-
-     # hybrid
-     sparse = bm25.encode_queries(query)
-     dense = model.encode(query).tolist()
-     hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
-
-     # initial search
-     result = index.query(
-         top_k=12,
-         vector=hdense,
-         sparse_vector=hsparse,
-         include_metadata=True,
-         filter=filter if filter else None
-     )
-
-     # fallback: if zero results with gender, relax gender
-     if gender and len(result["matches"]) == 0:
-         print(f"⚠️ No results with gender {gender}, relaxing gender filter")
-         filter.pop("gender")
-         result = index.query(
-             top_k=12,
-             vector=hdense,
-             sparse_vector=hsparse,
-             include_metadata=True,
-             filter=filter if filter else None
-         )
-
-     # results
-     imgs_with_captions = []
-     for r in result["matches"]:
-         idx = int(r["id"])
-         img = images[idx]
-         meta = r.get("metadata", {})
-         if not isinstance(img, Image.Image):
-             img = Image.fromarray(np.array(img))
-         padded = ImageOps.pad(img, (256, 256), color="white")
-         caption = str(meta.get("productDisplayName", "Unknown Product"))
-         imgs_with_captions.append((padded, caption))
-
-     return imgs_with_captions
-
- # Search by image only
- def search_by_image_only(uploaded_image, top_k=12):
-     if uploaded_image is None:
-         return []
-
-     uploaded_image = uploaded_image.convert("RGB")
-     dense_vec = model.encode(uploaded_image).tolist()
-
-     result = index.query(
-         vector=dense_vec,
-         top_k=top_k,
-         include_metadata=True
-     )
-
-     imgs_with_captions = []
-     for r in result["matches"]:
-         idx = int(r["id"])
-         img = images[idx]
-         meta = r.get("metadata", {})
-         if not isinstance(img, Image.Image):
-             img = Image.fromarray(np.array(img))
-         padded = ImageOps.pad(img, (256, 256), color="white")
-         caption = meta.get("productDisplayName", "Unknown Product")
-         imgs_with_captions.append((padded, caption))
-
-     return imgs_with_captions
-
- # Gradio UI
- import gradio as gr
-
- def search_fashion(query, alpha):
-     # Replace this stub with your real hybrid search logic
-     return [("Image", f"Result from text: {query} with alpha={alpha}") for _ in range(8)]
-
- def search_by_image_only(image):
-     # Replace this stub with your real image-based search logic
-     return [("Image", "Result from image search") for _ in range(6)]
-
- with gr.Blocks() as demo:
-     gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
-
-     with gr.Row():
-         with gr.Column():
-             query = gr.Textbox(label="Enter your fashion search query")
-             alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
-             search_btn = gr.Button("🔍 Search by Text")
-             search_results = gr.Gallery(label="Search Results", columns=8, height="40vh")
-             search_btn.click(fn=search_fashion, inputs=[query, alpha], outputs=search_results)
-
-         with gr.Column():
-             image_input = gr.Image(source="webcam", type="pil", label="📷 Capture an Image")
-             image_search_btn = gr.Button("🔍 Search by Image")
-             image_results = gr.Gallery(label="Image-Based Results", columns=6, height="40vh")
-             image_search_btn.click(fn=search_by_image_only, inputs=image_input, outputs=image_results)
-
- demo.launch()
-
-
-
-
- # # ------------------- Imports -------------------
-

  # import os
- # from pinecone import Pinecone, ServerlessSpec
- # from PIL import Image, ImageOps
  # import numpy as np
  # from datasets import load_dataset
  # from pinecone_text.sparse import BM25Encoder
  # from sentence_transformers import SentenceTransformer
  # import torch
- # from tqdm.auto import tqdm
  # import gradio as gr

- # # ------------------- Pinecone Setup -------------------
  # os.environ["PINECONE_API_KEY"] = "pcsk_************"  # hardcoded key redacted
  # api_key = os.environ.get('PINECONE_API_KEY')
  # pc = Pinecone(api_key=api_key)

  # index_name = "hybrid-image-search"
- # spec = ServerlessSpec(cloud="aws", region="us-east-1")

  # if index_name not in pc.list_indexes().names():
- #     pc.create_index(index_name, dimension=512, metric="dotproduct", spec=spec)
- # import time
  #     while not pc.describe_index(index_name).status['ready']:
  #         time.sleep(1)

  # index = pc.Index(index_name)

- # # ------------------- Dataset Loading -------------------
  # fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
  # images = fashion["image"]
  # metadata = fashion.remove_columns("image").to_pandas()

- # # ------------------- Encoders -------------------
  # bm25 = BM25Encoder()
- # bm25.fit(metadata["productDisplayName"])
- # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu')

- # # ------------------- Hybrid Scaling -------------------
- # def hybrid_scale(dense, sparse, alpha: float):

  #     if alpha < 0 or alpha > 1:
  #         raise ValueError("Alpha must be between 0 and 1")
- #     # scale sparse and dense vectors to create hybrid search vecs
  #     hsparse = {
  #         'indices': sparse['indices'],
- #         'values': [v * (1 - alpha) for v in sparse['values']]
  #     }
  #     hdense = [v * alpha for v in dense]
  #     return hdense, hsparse

-
- # # def search_fashion(query: str, alpha: float):
- # #     sparse = bm25.encode_queries(query)
- # #     dense = model.encode(query).tolist()
- # #     hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
- # #     result = index.query(
- # #         top_k=8,
- # #         vector=hdense,
- # #         sparse_vector=hsparse,
- # #         include_metadata=True
- # #     )
- # #     imgs = [images[int(r["id"])] for r in result["matches"]]
- # #     return imgs
-
-
- # # ------------------- Metadata Filter Extraction -------------------
  # from PIL import Image, ImageOps
  # import numpy as np

@@ -452,7 +179,7 @@ demo.launch()
  #         print(f"⚠️ No results with gender {gender}, relaxing gender filter")
  #         filter.pop("gender")
  #         result = index.query(
- # top_k=12,
+ #             top_k=12,
  #             vector=hdense,
  #             sparse_vector=hsparse,
  #             include_metadata=True,
@@ -472,6 +199,8 @@ demo.launch()
  #         imgs_with_captions.append((padded, caption))

  #     return imgs_with_captions
+
+ # # Search by image only
  # def search_by_image_only(uploaded_image, top_k=12):
  #     if uploaded_image is None:
  #         return []
@@ -498,36 +227,335 @@ demo.launch()

  #     return imgs_with_captions

-
- # # ------------------- Gradio UI -------------------
- # custom_css = """
- # .search-btn { width: 100%; }
- # .gr-row { gap: 8px !important; }
- # .query-slider > div { margin-bottom: 4px !important; }
- # """
  # import gradio as gr

  # with gr.Blocks() as demo:
  #     gr.Markdown("# 🛍️ Fashion Product Hybrid Search")

- #     query = gr.Textbox(label="Enter your fashion search query")
- #     alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
- #     search_btn = gr.Button("Search")

- #     gallery = gr.Gallery(label="Search Results", columns=8, height="40vh")

- #     def run_search(q, a):
- #         return search_fashion(q, a)

- #     search_btn.click(run_search, inputs=[query, alpha], outputs=gallery)

- #     # ⬇️ ADD THIS RIGHT *HERE*, before `demo.launch()`
- #     gr.Markdown("## 🔍 Search Visually Similar Products by Uploading an Image")

- #     image_input = gr.Image(type="pil", label="Upload an image")
- #     image_search_btn = gr.Button("Search by Image Only")
- #     image_results = gr.Gallery(label="Image-Based Results", columns=6, height="40vh")

- #     image_search_btn.click(fn=search_by_image_only, inputs=image_input, outputs=image_results)

- # demo.launch()
+ # """Hybrid Multimodal Vector Search for E-Commerce Product Discovery"""

  # import os
+ # import time
  # import numpy as np
+ # from PIL import Image, ImageOps
  # from datasets import load_dataset
+ # from pinecone import Pinecone, ServerlessSpec
  # from pinecone_text.sparse import BM25Encoder
  # from sentence_transformers import SentenceTransformer
  # import torch
  # import gradio as gr
+ # import pandas as pd

+ # # Set Pinecone API Key and config
  # os.environ["PINECONE_API_KEY"] = "pcsk_************"  # hardcoded key redacted
  # api_key = os.environ.get('PINECONE_API_KEY')
  # pc = Pinecone(api_key=api_key)

+ # cloud = os.environ.get('PINECONE_CLOUD', 'aws')
+ # region = os.environ.get('PINECONE_REGION', 'us-east-1')
+ # spec = ServerlessSpec(cloud=cloud, region=region)
  # index_name = "hybrid-image-search"

+ # # Create and connect to index
  # if index_name not in pc.list_indexes().names():
+ #     pc.create_index(index_name, dimension=512, metric='dotproduct', spec=spec)
  #     while not pc.describe_index(index_name).status['ready']:
  #         time.sleep(1)

  # index = pc.Index(index_name)
+ # index.describe_index_stats()

+ # # Load dataset
  # fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
  # images = fashion["image"]
  # metadata = fashion.remove_columns("image").to_pandas()

+ # # Fit BM25
  # bm25 = BM25Encoder()
+ # bm25.fit(metadata['productDisplayName'])

+ # # Load CLIP model
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ # model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)

+ # # Hybrid scaler
+ # def hybrid_scale(dense, sparse, alpha: float):
  #     if alpha < 0 or alpha > 1:
  #         raise ValueError("Alpha must be between 0 and 1")
  #     hsparse = {
  #         'indices': sparse['indices'],
+ #         'values': [v * (1 - alpha) for v in sparse['values']]
  #     }
  #     hdense = [v * alpha for v in dense]
  #     return hdense, hsparse

+ # # Metadata filter extractor
  # from PIL import Image, ImageOps
  # import numpy as np
+ # # Gradio UI
  # import gradio as gr

+ # def search_fashion(query, alpha):
+ #     # Replace this stub with your real hybrid search logic
+ #     return [("Image", f"Result from text: {query} with alpha={alpha}") for _ in range(8)]
+
+ # def search_by_image_only(image):
+ #     # Replace this stub with your real image-based search logic
+ #     return [("Image", "Result from image search") for _ in range(6)]
+
  # with gr.Blocks() as demo:
  #     gr.Markdown("# 🛍️ Fashion Product Hybrid Search")

+ #     with gr.Row():
+ #         with gr.Column():
+ #             query = gr.Textbox(label="Enter your fashion search query")
+ #             alpha = gr.Slider(0, 1, value=0.5, label="Hybrid Weight (alpha: 0=sparse, 1=dense)")
+ #             search_btn = gr.Button("🔍 Search by Text")
+ #             search_results = gr.Gallery(label="Search Results", columns=8, height="40vh")
+ #             search_btn.click(fn=search_fashion, inputs=[query, alpha], outputs=search_results)
+
+ #         with gr.Column():
+ #             image_input = gr.Image(source="webcam", type="pil", label="📷 Capture an Image")
+ #             image_search_btn = gr.Button("🔍 Search by Image")
+ #             image_results = gr.Gallery(label="Image-Based Results", columns=6, height="40vh")
+ #             image_search_btn.click(fn=search_by_image_only, inputs=image_input, outputs=image_results)
+
+ # demo.launch()
+
+
+ # ------------------- Imports -------------------
+
+ import os
+ import time
+ from pinecone import Pinecone, ServerlessSpec
+ from PIL import Image, ImageOps
+ import numpy as np
+ from datasets import load_dataset
+ from pinecone_text.sparse import BM25Encoder
+ from sentence_transformers import SentenceTransformer
+ import torch
+ from tqdm.auto import tqdm
+ import gradio as gr
+
+ # ------------------- Pinecone Setup -------------------
+ os.environ["PINECONE_API_KEY"] = "pcsk_************"  # hardcoded key redacted; set PINECONE_API_KEY in the environment instead
+ api_key = os.environ.get('PINECONE_API_KEY')
+ pc = Pinecone(api_key=api_key)
+
+ index_name = "hybrid-image-search"
+ spec = ServerlessSpec(cloud="aws", region="us-east-1")
+
+ if index_name not in pc.list_indexes().names():
+     pc.create_index(index_name, dimension=512, metric="dotproduct", spec=spec)
+     while not pc.describe_index(index_name).status['ready']:
+         time.sleep(1)
+
+ index = pc.Index(index_name)
+
+ # ------------------- Dataset Loading -------------------
+ fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
+ images = fashion["image"]
+ metadata = fashion.remove_columns("image").to_pandas()
+
+ # ------------------- Encoders -------------------
+ bm25 = BM25Encoder()
+ bm25.fit(metadata["productDisplayName"])
+ model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device='cuda' if torch.cuda.is_available() else 'cpu')
+
+ # ------------------- Hybrid Scaling -------------------
+ def hybrid_scale(dense, sparse, alpha: float):
+     if alpha < 0 or alpha > 1:
+         raise ValueError("Alpha must be between 0 and 1")
+     # scale sparse and dense vectors to create hybrid search vecs
+     hsparse = {
+         'indices': sparse['indices'],
+         'values': [v * (1 - alpha) for v in sparse['values']]
+     }
+     hdense = [v * alpha for v in dense]
+     return hdense, hsparse
+
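(A quick sanity check on the scaling above — a standalone sketch that assumes the hybrid_scale just defined, not part of the commit: alpha=1 zeroes the sparse values and alpha=0 zeroes the dense ones, so under the dotproduct metric the final score is a convex combination of the BM25 and CLIP scores.)

    dense = [0.5, 0.25]
    sparse = {'indices': [7, 42], 'values': [1.0, 2.0]}

    hd, hs = hybrid_scale(dense, sparse, alpha=1.0)   # pure dense search
    assert hd == dense and hs['values'] == [0.0, 0.0]

    hd, hs = hybrid_scale(dense, sparse, alpha=0.0)   # pure BM25 search
    assert hd == [0.0, 0.0] and hs['values'] == sparse['values']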
+
+ # def search_fashion(query: str, alpha: float):
+ #     sparse = bm25.encode_queries(query)
+ #     dense = model.encode(query).tolist()
+ #     hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
+ #     result = index.query(
+ #         top_k=8,
+ #         vector=hdense,
+ #         sparse_vector=hsparse,
+ #         include_metadata=True
+ #     )
+ #     imgs = [images[int(r["id"])] for r in result["matches"]]
+ #     return imgs

+ # ------------------- Metadata Filter Extraction -------------------
+ from PIL import Image, ImageOps
+ import numpy as np
+
+ def extract_metadata_filters(query: str):
+     query_lower = query.lower()
+     gender = None
+     category = None
+     subcategory = None
+     color = None
+
+     # --- Gender Mapping ---
+     gender_map = {
+         "men": "Men", "man": "Men", "mens": "Men", "mans": "Men", "male": "Men",
+         "women": "Women", "woman": "Women", "womens": "Women", "female": "Women",
+         "boys": "Boys", "boy": "Boys",
+         "girls": "Girls", "girl": "Girls",
+         "kids": "Kids", "unisex": "Unisex"
+     }
+     for term, mapped_value in gender_map.items():
+         if term in query_lower:
+             gender = mapped_value
+             break
+
+     # --- Category Mapping ---
+     category_map = {
+         "shirt": "Shirts",
+         "tshirt": "Tshirts", "t-shirt": "Tshirts",
+         "jeans": "Jeans",
+         "watch": "Watches",
+         "kurta": "Kurtas",
+         "dress": "Dresses", "dresses": "Dresses",
+         "trousers": "Trousers", "pants": "Trousers",
+         "shorts": "Shorts",
+         "footwear": "Footwear",
+         "shoes": "Footwear",
+         "fashion": "Apparel"
+     }
+     for term, mapped_value in category_map.items():
+         if term in query_lower:
+             category = mapped_value
+             break
+
+     # --- SubCategory Mapping ---
+     subCategory_list = [
+         "Accessories", "Apparel Set", "Bags", "Bath and Body", "Beauty Accessories",
+         "Belts", "Bottomwear", "Cufflinks", "Dress", "Eyes", "Eyewear", "Flip Flops",
+         "Fragrance", "Free Gifts", "Gloves", "Hair", "Headwear", "Home Furnishing",
+         "Innerwear", "Jewellery", "Lips", "Loungewear and Nightwear", "Makeup",
+         "Mufflers", "Nails", "Perfumes", "Sandal", "Saree", "Scarves", "Shoe Accessories",
+         "Shoes", "Skin", "Skin Care", "Socks", "Sports Accessories", "Sports Equipment",
+         "Stoles", "Ties", "Topwear", "Umbrellas", "Vouchers", "Wallets", "Watches",
+         "Water Bottle", "Wristbands"
+     ]
+     if "topwear" in query_lower or "top" in query_lower:
+         subcategory = "Topwear"
+     else:
+         for subcat in subCategory_list:
+             if subcat.lower() in query_lower:
+                 subcategory = subcat
+                 break
+
+     # --- Color Extraction ---
+     colors = [
+         "red","blue","green","yellow","black","white",
+         "orange","pink","purple","brown","grey","beige"
+     ]
+     for c in colors:
+         if c in query_lower:
+             color = c.capitalize()
+             break
+
+     # --- Invalid pairs ---
+     invalid_pairs = {
+         ("Men", "Dresses"), ("Men", "Sarees"), ("Men", "Skirts"),
+         ("Boys", "Dresses"), ("Boys", "Sarees"),
+         ("Girls", "Boxers"), ("Men", "Heels")
+     }
+     if (gender, category) in invalid_pairs:
+         print(f"⚠️ Invalid pair: {gender} + {category}, dropping gender")
+         gender = None
+
+     # fallback
+     if gender and not category:
+         category = "Apparel"
+
+     return gender, category, subcategory, color
+
+
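(Illustrative only, not part of the commit — what the extractor above returns for a sample query, given the mappings as written:)

    gender, category, subcategory, color = extract_metadata_filters("red kurta for men")
    # -> ("Men", "Kurtas", None, "Red"): "men" and "kurta" hit the gender and
    #    category maps, no subCategory keyword appears, and "red" is the colour.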
+ def search_fashion(query: str, alpha: float):
+     gender, category, subcategory, color = extract_metadata_filters(query)
+
+     # Build Pinecone filter
+     filter = {}
+     if gender:
+         filter["gender"] = gender
+     if category:
+         filter["articleType"] = category
+     if subcategory:
+         filter["subCategory"] = subcategory
+     if color:
+         filter["baseColour"] = color
+
+     print(f"🔍 Using filter: {filter}")
+
+     # hybrid
+     sparse = bm25.encode_queries(query)
+     dense = model.encode(query).tolist()
+     hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)
+
+     # initial search
+     result = index.query(
+         top_k=12,
+         vector=hdense,
+         sparse_vector=hsparse,
+         include_metadata=True,
+         filter=filter if filter else None
+     )
+
+     # fallback: if zero results with gender, relax gender
+     if gender and len(result["matches"]) == 0:
+         print(f"⚠️ No results with gender {gender}, relaxing gender filter")
+         filter.pop("gender")
+         result = index.query(
+             top_k=12,
+             vector=hdense,
+             sparse_vector=hsparse,
+             include_metadata=True,
+             filter=filter if filter else None
+         )
+
+     # results
+     imgs_with_captions = []
+     for r in result["matches"]:
+         idx = int(r["id"])
+         img = images[idx]
+         meta = r.get("metadata", {})
+         if not isinstance(img, Image.Image):
+             img = Image.fromarray(np.array(img))
+         padded = ImageOps.pad(img, (256, 256), color="white")
+         caption = str(meta.get("productDisplayName", "Unknown Product"))
+         imgs_with_captions.append((padded, caption))
+
+     return imgs_with_captions
+
+ def search_by_image_only(uploaded_image, top_k=12):
+     if uploaded_image is None:
+         return []
+
+     uploaded_image = uploaded_image.convert("RGB")
+     dense_vec = model.encode(uploaded_image).tolist()
+
+     result = index.query(
+         vector=dense_vec,
+         top_k=top_k,
+         include_metadata=True
+     )
+
+     imgs_with_captions = []
+     for r in result["matches"]:
+         idx = int(r["id"])
+         img = images[idx]
+         meta = r.get("metadata", {})
+         if not isinstance(img, Image.Image):
+             img = Image.fromarray(np.array(img))
+         padded = ImageOps.pad(img, (256, 256), color="white")
+         caption = meta.get("productDisplayName", "Unknown Product")
+         imgs_with_captions.append((padded, caption))
+
+     return imgs_with_captions
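(Why image-only queries can hit the same index — a standalone sketch, assuming the `model` loaded above: clip-ViT-B-32 embeds text and PIL images into one shared 512-dimensional space, which matches the dimension the index was created with.)

    vec = model.encode("blue denim jacket")
    print(vec.shape)   # (512,) — same dimensionality an encoded image produces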
+
+
+ # ------------------- Gradio UI -------------------
+ custom_css = """
+ .search-btn {
+     width: 100%;
+ }
+ .gr-row {
+     gap: 8px !important; /* slightly tighter column gap */
+ }
+ .query-slider > div {
+     margin-bottom: 4px !important; /* reduce space between textbox and slider */
+ }
+ """
+
+ with gr.Blocks(css=custom_css) as demo:
+     gr.Markdown("# 🛍️ Fashion Product Hybrid Search")
+
+     with gr.Row(equal_height=True):
+         with gr.Column(scale=5, elem_classes="query-slider"):
+             query = gr.Textbox(
+                 label="Enter your fashion search query",
+                 placeholder="Type something or leave blank to only use the image"
+             )
+             alpha = gr.Slider(
+                 0, 1, value=0.5,
+                 label="Hybrid Weight (alpha: 0=sparse, 1=dense)"
+             )
+         with gr.Column(scale=1):
+             image_input = gr.Image(
+                 type="pil",
+                 label="Upload an image (optional)",
+                 height=256,
+                 width=356,
+                 show_label=True
+             )
+
+     search_btn = gr.Button("Search", elem_classes="search-btn")
+
+     gallery = gr.Gallery(
+         label="Search Results",
+         columns=6,
+         height="40vh"
+     )
+
+     def unified_search(q, uploaded_image, a):
+         if uploaded_image is not None:
+             # image search is dense-only; alpha does not apply here
+             return search_by_image_only(uploaded_image)
+         elif q.strip() != "":
+             return search_fashion(q, a)
+         else:
+             return []
+
+     search_btn.click(
+         unified_search,
+         inputs=[query, image_input, alpha],
+         outputs=gallery
+     )
+
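(Dispatch order above, spelled out — an illustrative sketch, not part of the commit; `pil_image` stands for any PIL.Image value: an uploaded image always wins, text plus alpha runs the hybrid path, and with neither input the gallery stays empty.)

    unified_search("", pil_image, 0.5)        # image-only CLIP search
    unified_search("red kurta", None, 0.5)    # hybrid BM25 + CLIP search
    unified_search("", None, 0.5)             # -> []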
+     gr.Markdown("Powered by your hybrid AI search model 🚀")
+
+ demo.launch()