sibthinon commited on
Commit
68b12c7
·
verified ·
1 Parent(s): fd2a235

use only model bge

Browse files
Files changed (1) hide show
  1. app.py +14 -37
app.py CHANGED
@@ -34,20 +34,9 @@ with open("keyword_whitelist.pkl", "rb") as f:
34
  keyword_whitelist = pickle.load(f)
35
 
36
  # Preload Models
37
- models = {
38
- "E5 Finetuned": {
39
- "model": SentenceTransformer("e5_finetuned"),
40
- "collection": "product_E5_finetune",
41
- "threshold": 0.8,
42
- "prefix": "query: "
43
- },
44
- "BGE M3": {
45
- "model": SentenceTransformer("BAAI/bge-m3"),
46
- "collection": "product_bge-m3",
47
- "threshold": 0.45,
48
- "prefix": ""
49
- }
50
- }
51
 
52
  reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
53
 
@@ -105,21 +94,14 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
105
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
106
 
107
  # Search Function
108
- def search_product(query, model_choice):
109
  start_time = time.time()
110
  latest_query_result["raw_query"] = query
111
 
112
- selected = models[model_choice]
113
- model = selected["model"]
114
- collection_name = selected["collection"]
115
- threshold = selected["threshold"]
116
- prefix = selected["prefix"]
117
-
118
  corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
119
- query_embed = model.encode(prefix + corrected_query)
120
 
121
  try:
122
- # 🔍 ดึง top-50 ก่อน rerank
123
  result = qdrant_client.query_points(
124
  collection_name=collection_name,
125
  query=query_embed.tolist(),
@@ -130,14 +112,12 @@ def search_product(query, model_choice):
130
  except Exception as e:
131
  return f"<p>❌ Qdrant error: {str(e)}</p>"
132
 
133
- # ✅ Rerank Top 10 ด้วย Cross-Encoder (เฉพาะ BGE M3 เท่านั้น)
134
- if model_choice == "BGE M3" and len(result) > 0:
135
  topk = 10
136
  docs = [r.payload.get("name", "") for r in result[:topk]]
137
  pairs = [[corrected_query, d] for d in docs]
138
  scores = reranker.compute_score(pairs, normalize=True)
139
-
140
- # ผสมคะแนน: 0.6 จาก embedding, 0.4 จาก reranker
141
  result[:topk] = sorted(
142
  zip(result[:topk], scores),
143
  key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
@@ -189,11 +169,11 @@ def search_product(query, model_choice):
189
  return html_output
190
 
191
  # Feedback Function
192
- def log_feedback(feedback, model_choice):
193
  try:
194
  now = datetime.now().strftime("%Y-%m-%d")
195
  table.create({
196
- "model": model_choice,
197
  "timestamp": now,
198
  "raw_query": latest_query_result["raw_query"],
199
  "query": latest_query_result["query"],
@@ -207,12 +187,9 @@ def log_feedback(feedback, model_choice):
207
 
208
  # Gradio UI
209
  with gr.Blocks() as demo:
210
- gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
211
-
212
- with gr.Row():
213
- model_selector = gr.Dropdown(label="🔍 เลือกโมเดล", choices=list(models.keys()), value="E5 Finetuned")
214
- query_input = gr.Textbox(label="พิมพ์คำค้นหา")
215
 
 
216
  result_output = gr.HTML(label="📋 ผลลัพธ์")
217
 
218
  with gr.Row():
@@ -221,8 +198,8 @@ with gr.Blocks() as demo:
221
 
222
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
223
 
224
- query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
225
- match_btn.click(fn=lambda model: log_feedback("match", model), inputs=model_selector, outputs=feedback_status)
226
- not_match_btn.click(fn=lambda model: log_feedback("not_match", model), inputs=model_selector, outputs=feedback_status)
227
 
228
  demo.launch(share=True)
 
34
  keyword_whitelist = pickle.load(f)
35
 
36
  # Preload Models
37
+ model = SentenceTransformer("BAAI/bge-m3")
38
+ collection_name = "product_bge-m3"
39
+ threshold = 0.45
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
42
 
 
94
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
95
 
96
  # Search Function
97
+ def search_product(query):
98
  start_time = time.time()
99
  latest_query_result["raw_query"] = query
100
 
 
 
 
 
 
 
101
  corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
102
+ query_embed = model.encode(corrected_query)
103
 
104
  try:
 
105
  result = qdrant_client.query_points(
106
  collection_name=collection_name,
107
  query=query_embed.tolist(),
 
112
  except Exception as e:
113
  return f"<p>❌ Qdrant error: {str(e)}</p>"
114
 
115
+ # ✅ Rerank Top 10
116
+ if len(result) > 0:
117
  topk = 10
118
  docs = [r.payload.get("name", "") for r in result[:topk]]
119
  pairs = [[corrected_query, d] for d in docs]
120
  scores = reranker.compute_score(pairs, normalize=True)
 
 
121
  result[:topk] = sorted(
122
  zip(result[:topk], scores),
123
  key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
 
169
  return html_output
170
 
171
  # Feedback Function
172
+ def log_feedback(feedback):
173
  try:
174
  now = datetime.now().strftime("%Y-%m-%d")
175
  table.create({
176
+ "model": "BGE M3",
177
  "timestamp": now,
178
  "raw_query": latest_query_result["raw_query"],
179
  "query": latest_query_result["query"],
 
187
 
188
  # Gradio UI
189
  with gr.Blocks() as demo:
190
+ gr.Markdown("## 🔎 Product Semantic Search (BGE M3 + Qdrant)")
 
 
 
 
191
 
192
+ query_input = gr.Textbox(label="พิมพ์คำค้นหา")
193
  result_output = gr.HTML(label="📋 ผลลัพธ์")
194
 
195
  with gr.Row():
 
198
 
199
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
200
 
201
+ query_input.submit(search_product, inputs=[query_input], outputs=result_output)
202
+ match_btn.click(fn=lambda: log_feedback("match"), outputs=feedback_status)
203
+ not_match_btn.click(fn=lambda: log_feedback("not_match"), outputs=feedback_status)
204
 
205
  demo.launch(share=True)