sibthinon commited on
Commit
dbd7784
·
verified ·
1 Parent(s): 8933ccb

add rerank

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -13,6 +13,7 @@ from pyairtable import Api
13
  import pickle
14
  import re
15
  import unicodedata
 
16
 
17
  # Setup Qdrant Client
18
  qdrant_client = QdrantClient(
@@ -43,15 +44,23 @@ models = {
43
  "BGE M3": {
44
  "model": SentenceTransformer("BAAI/bge-m3"),
45
  "collection": "product_bge-m3",
46
- "threshold": 0.5,
47
  "prefix": ""
48
  }
49
  }
50
 
 
 
51
  # Utils
52
  def is_non_thai(text):
53
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
54
 
 
 
 
 
 
 
55
  def normalize(text: str) -> str:
56
  if is_non_thai(text):
57
  return text.strip()
@@ -90,7 +99,7 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
90
  if not matched:
91
  corrected.append(tokens[i])
92
  i += 1
93
- return "".join([word for word in corrected if len(word) > 1 or word in whitelist])
94
 
95
  # Global state
96
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
@@ -110,6 +119,7 @@ def search_product(query, model_choice):
110
  query_embed = model.encode(prefix + corrected_query)
111
 
112
  try:
 
113
  result = qdrant_client.query_points(
114
  collection_name=collection_name,
115
  query=query_embed.tolist(),
@@ -120,11 +130,25 @@ def search_product(query, model_choice):
120
  except Exception as e:
121
  return f"<p>❌ Qdrant error: {str(e)}</p>"
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  elapsed = time.time() - start_time
124
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
125
  if corrected_query != query:
126
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
127
-
128
  html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
129
  result_summary, found = "", False
130
 
 
13
  import pickle
14
  import re
15
  import unicodedata
16
+ from FlagEmbedding import FlagReranker
17
 
18
  # Setup Qdrant Client
19
  qdrant_client = QdrantClient(
 
44
  "BGE M3": {
45
  "model": SentenceTransformer("BAAI/bge-m3"),
46
  "collection": "product_bge-m3",
47
+ "threshold": 0.45,
48
  "prefix": ""
49
  }
50
  }
51
 
52
+ reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
53
+
54
  # Utils
55
  def is_non_thai(text):
56
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
57
 
58
+ def join_corrected_tokens(corrected: list) -> str:
59
+ if corrected and is_non_thai("".join(corrected)):
60
+ return " ".join([w for w in corrected if len(w) > 1 or w in keyword_whitelist])
61
+ else:
62
+ return "".join([w for w in corrected if len(w) > 1 or w in keyword_whitelist])
63
+
64
  def normalize(text: str) -> str:
65
  if is_non_thai(text):
66
  return text.strip()
 
99
  if not matched:
100
  corrected.append(tokens[i])
101
  i += 1
102
+ return join_corrected_tokens(corrected)
103
 
104
  # Global state
105
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
 
119
  query_embed = model.encode(prefix + corrected_query)
120
 
121
  try:
122
+ # 🔍 ดึง top-50 ก่อน rerank
123
  result = qdrant_client.query_points(
124
  collection_name=collection_name,
125
  query=query_embed.tolist(),
 
130
  except Exception as e:
131
  return f"<p>❌ Qdrant error: {str(e)}</p>"
132
 
133
+ # ✅ Rerank Top 10 ด้วย Cross-Encoder (เฉพาะ BGE M3 เท่านั้น)
134
+ if model_choice == "BGE M3" and len(result) > 0:
135
+ topk = 10
136
+ docs = [r.payload.get("name", "") for r in result[:topk]]
137
+ pairs = [[corrected_query, d] for d in docs]
138
+ scores = reranker.compute_score(pairs, normalize=True)
139
+
140
+ # ผสมคะแนน: 0.6 จาก embedding, 0.4 จาก reranker
141
+ result[:topk] = sorted(
142
+ zip(result[:topk], scores),
143
+ key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
144
+ reverse=True
145
+ )
146
+ result[:topk] = [r[0] for r in result[:topk]]
147
+
148
  elapsed = time.time() - start_time
149
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
150
  if corrected_query != query:
151
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
 
152
  html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
153
  result_summary, found = "", False
154