sibthinon committed on
Commit
08defce
·
verified ·
1 Parent(s): 1a14f7c

Add the ability to select which model is used for search

Browse files
Files changed (1) hide show
  1. app.py +38 -17
app.py CHANGED
@@ -28,14 +28,26 @@ TABLE_NAME = "Feedback_search"
28
  api = Api(AIRTABLE_API_KEY)
29
  table = api.table(BASE_ID, TABLE_NAME)
30
 
31
- # Load model
32
- model = SentenceTransformer('e5_finetuned')
33
- collection_name = "product_E5_finetune"
34
-
35
  # Load whitelist
36
  with open("keyword_whitelist.pkl", "rb") as f:
37
  keyword_whitelist = pickle.load(f)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # Utils
40
  def is_non_thai(text):
41
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
@@ -83,12 +95,19 @@ def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3
83
  # Global state
84
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
85
 
86
- # Main Search
87
- def search_product(query):
88
  start_time = time.time()
89
  latest_query_result["raw_query"] = query
 
 
 
 
 
 
 
90
  corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
91
- query_embed = model.encode("query: " + corrected_query)
92
 
93
  try:
94
  result = qdrant_client.query_points(
@@ -107,10 +126,10 @@ def search_product(query):
107
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
108
 
109
  html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
110
-
111
  result_summary, found = "", False
 
112
  for res in result:
113
- if res.score > 0.8:
114
  found = True
115
  name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
116
  score = f"{res.score:.4f}"
@@ -145,12 +164,12 @@ def search_product(query):
145
 
146
  return html_output
147
 
148
- # Feedback logging
149
- def log_feedback(feedback):
150
  try:
151
  now = datetime.now().strftime("%Y-%m-%d")
152
  table.create({
153
- "model": "E5 (intfloat/multilingual-e5-small)",
154
  "timestamp": now,
155
  "raw_query": latest_query_result["raw_query"],
156
  "query": latest_query_result["query"],
@@ -166,7 +185,10 @@ def log_feedback(feedback):
166
  with gr.Blocks() as demo:
167
  gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
168
 
169
- query_input = gr.Textbox(label="พิมพ์คำค้นหา")
 
 
 
170
  result_output = gr.HTML(label="📋 ผลลัพธ์")
171
 
172
  with gr.Row():
@@ -175,9 +197,8 @@ with gr.Blocks() as demo:
175
 
176
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
177
 
178
- query_input.submit(search_product, inputs=[query_input], outputs=result_output)
179
- match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
180
- not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)
181
 
182
- # Run
183
  demo.launch(share=True)
 
28
  api = Api(AIRTABLE_API_KEY)
29
  table = api.table(BASE_ID, TABLE_NAME)
30
 
 
 
 
 
31
  # Load whitelist
32
  with open("keyword_whitelist.pkl", "rb") as f:
33
  keyword_whitelist = pickle.load(f)
34
 
35
+ # Preload Models
36
+ models = {
37
+ "E5 Finetuned": {
38
+ "model": SentenceTransformer("e5_finetuned"),
39
+ "collection": "product_E5_finetune",
40
+ "threshold": 0.8,
41
+ "prefix": "query: "
42
+ },
43
+ "BGE M3": {
44
+ "model": SentenceTransformer("BAAI/bge-m3"),
45
+ "collection": "product_bge-m3",
46
+ "threshold": 0.5,
47
+ "prefix": ""
48
+ }
49
+ }
50
+
51
  # Utils
52
  def is_non_thai(text):
53
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
 
95
  # Global state
96
  latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
97
 
98
+ # Search Function
99
+ def search_product(query, model_choice):
100
  start_time = time.time()
101
  latest_query_result["raw_query"] = query
102
+
103
+ selected = models[model_choice]
104
+ model = selected["model"]
105
+ collection_name = selected["collection"]
106
+ threshold = selected["threshold"]
107
+ prefix = selected["prefix"]
108
+
109
  corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
110
+ query_embed = model.encode(prefix + corrected_query)
111
 
112
  try:
113
  result = qdrant_client.query_points(
 
126
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
127
 
128
  html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
 
129
  result_summary, found = "", False
130
+
131
  for res in result:
132
+ if res.score >= threshold:
133
  found = True
134
  name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
135
  score = f"{res.score:.4f}"
 
164
 
165
  return html_output
166
 
167
+ # Feedback Function
168
+ def log_feedback(feedback, model_choice):
169
  try:
170
  now = datetime.now().strftime("%Y-%m-%d")
171
  table.create({
172
+ "model": model_choice,
173
  "timestamp": now,
174
  "raw_query": latest_query_result["raw_query"],
175
  "query": latest_query_result["query"],
 
185
  with gr.Blocks() as demo:
186
  gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
187
 
188
+ with gr.Row():
189
+ model_selector = gr.Dropdown(label="🔍 เลือกโมเดล", choices=list(models.keys()), value="E5 Finetuned")
190
+ query_input = gr.Textbox(label="พิมพ์คำค้นหา")
191
+
192
  result_output = gr.HTML(label="📋 ผลลัพธ์")
193
 
194
  with gr.Row():
 
197
 
198
  feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
199
 
200
+ query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
201
+ match_btn.click(fn=lambda model: log_feedback("match", model), inputs=model_selector, outputs=feedback_status)
202
+ not_match_btn.click(fn=lambda model: log_feedback("not_match", model), inputs=model_selector, outputs=feedback_status)
203
 
 
204
  demo.launch(share=True)