Spaces:
Running
Running
add rerank
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ from pyairtable import Api
|
|
13 |
import pickle
|
14 |
import re
|
15 |
import unicodedata
|
|
|
16 |
|
17 |
# Setup Qdrant Client
|
18 |
qdrant_client = QdrantClient(
|
@@ -43,15 +44,23 @@ models = {
|
|
43 |
"BGE M3": {
|
44 |
"model": SentenceTransformer("BAAI/bge-m3"),
|
45 |
"collection": "product_bge-m3",
|
46 |
-
"threshold": 0.
|
47 |
"prefix": ""
|
48 |
}
|
49 |
}
|
50 |
|
|
|
|
|
51 |
# Utils
|
52 |
def is_non_thai(text):
|
53 |
return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def normalize(text: str) -> str:
|
56 |
if is_non_thai(text):
|
57 |
return text.strip()
|
@@ -90,7 +99,7 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
|
|
90 |
if not matched:
|
91 |
corrected.append(tokens[i])
|
92 |
i += 1
|
93 |
-
return
|
94 |
|
95 |
# Global state
|
96 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
@@ -110,6 +119,7 @@ def search_product(query, model_choice):
|
|
110 |
query_embed = model.encode(prefix + corrected_query)
|
111 |
|
112 |
try:
|
|
|
113 |
result = qdrant_client.query_points(
|
114 |
collection_name=collection_name,
|
115 |
query=query_embed.tolist(),
|
@@ -120,11 +130,25 @@ def search_product(query, model_choice):
|
|
120 |
except Exception as e:
|
121 |
return f"<p>❌ Qdrant error: {str(e)}</p>"
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
elapsed = time.time() - start_time
|
124 |
html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
|
125 |
if corrected_query != query:
|
126 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
127 |
-
|
128 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
129 |
result_summary, found = "", False
|
130 |
|
|
|
13 |
import pickle
|
14 |
import re
|
15 |
import unicodedata
|
16 |
+
from FlagEmbedding import FlagReranker
|
17 |
|
18 |
# Setup Qdrant Client
|
19 |
qdrant_client = QdrantClient(
|
|
|
44 |
"BGE M3": {
|
45 |
"model": SentenceTransformer("BAAI/bge-m3"),
|
46 |
"collection": "product_bge-m3",
|
47 |
+
"threshold": 0.45,
|
48 |
"prefix": ""
|
49 |
}
|
50 |
}
|
51 |
|
52 |
+
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
|
53 |
+
|
54 |
# Utils
|
55 |
def is_non_thai(text):
|
56 |
return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
|
57 |
|
58 |
+
def join_corrected_tokens(corrected: list) -> str:
|
59 |
+
if corrected and is_non_thai("".join(corrected)):
|
60 |
+
return " ".join([w for w in corrected if len(w) > 1 or w in keyword_whitelist])
|
61 |
+
else:
|
62 |
+
return "".join([w for w in corrected if len(w) > 1 or w in keyword_whitelist])
|
63 |
+
|
64 |
def normalize(text: str) -> str:
|
65 |
if is_non_thai(text):
|
66 |
return text.strip()
|
|
|
99 |
if not matched:
|
100 |
corrected.append(tokens[i])
|
101 |
i += 1
|
102 |
+
return join_corrected_tokens(corrected)
|
103 |
|
104 |
# Global state
|
105 |
latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
|
|
|
119 |
query_embed = model.encode(prefix + corrected_query)
|
120 |
|
121 |
try:
|
122 |
+
# 🔍 ดึง top-50 ก่อน rerank
|
123 |
result = qdrant_client.query_points(
|
124 |
collection_name=collection_name,
|
125 |
query=query_embed.tolist(),
|
|
|
130 |
except Exception as e:
|
131 |
return f"<p>❌ Qdrant error: {str(e)}</p>"
|
132 |
|
133 |
+
# ✅ Rerank Top 10 ด้วย Cross-Encoder (เฉพาะ BGE M3 เท่านั้น)
|
134 |
+
if model_choice == "BGE M3" and len(result) > 0:
|
135 |
+
topk = 10
|
136 |
+
docs = [r.payload.get("name", "") for r in result[:topk]]
|
137 |
+
pairs = [[corrected_query, d] for d in docs]
|
138 |
+
scores = reranker.compute_score(pairs, normalize=True)
|
139 |
+
|
140 |
+
# ผสมคะแนน: 0.6 จาก embedding, 0.4 จาก reranker
|
141 |
+
result[:topk] = sorted(
|
142 |
+
zip(result[:topk], scores),
|
143 |
+
key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
|
144 |
+
reverse=True
|
145 |
+
)
|
146 |
+
result[:topk] = [r[0] for r in result[:topk]]
|
147 |
+
|
148 |
elapsed = time.time() - start_time
|
149 |
html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
|
150 |
if corrected_query != query:
|
151 |
html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
|
|
|
152 |
html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
|
153 |
result_summary, found = "", False
|
154 |
|