sibthinon committed on
Commit
7a2742e
·
verified ·
1 Parent(s): cd3f6c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import gradio as gr
2
  import time
3
  from datetime import datetime
4
- import pandas as pd
5
  from sentence_transformers import SentenceTransformer
6
  from qdrant_client import QdrantClient
7
  from qdrant_client.models import Filter, FieldCondition, MatchValue
8
  import os
9
- from rapidfuzz import process, fuzz
10
  from pythainlp.tokenize import word_tokenize
11
  from pyairtable import Table
12
  from pyairtable import Api
13
- import pickle
14
  import re
15
  import unicodedata
16
- from FlagEmbedding import FlagReranker
17
 
18
  # Setup Qdrant Client
19
  qdrant_client = QdrantClient(
@@ -34,12 +31,10 @@ model = SentenceTransformer("BAAI/bge-m3")
34
  collection_name = "product_bge-m3"
35
  threshold = 0.45
36
 
37
- reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
38
-
39
  # Utils
40
  def is_non_thai(text):
41
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
42
-
43
  def normalize(text: str) -> str:
44
  if is_non_thai(text):
45
  return text.strip()
@@ -72,16 +67,26 @@ def search_product(query):
72
  return
73
 
74
  if len(result) > 0:
75
- topk = 10
76
- docs = [r.payload.get("name", "") for r in result[:topk]]
77
- pairs = [[corrected_query, d] for d in docs]
78
- scores = reranker.compute_score(pairs, normalize=True)
79
- result[:topk] = sorted(
80
- zip(result[:topk], scores),
81
- key=lambda x: 0.6 * x[0].score + 0.4 * x[1],
82
- reverse=True
83
- )
84
- result[:topk] = [r[0] for r in result[:topk]]
 
 
 
 
 
 
 
 
 
 
85
 
86
  elapsed = time.time() - start_time
87
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
 
1
  import gradio as gr
2
  import time
3
  from datetime import datetime
 
4
  from sentence_transformers import SentenceTransformer
5
  from qdrant_client import QdrantClient
6
  from qdrant_client.models import Filter, FieldCondition, MatchValue
7
  import os
8
+ from rapidfuzz import fuzz
9
  from pythainlp.tokenize import word_tokenize
10
  from pyairtable import Table
11
  from pyairtable import Api
 
12
  import re
13
  import unicodedata
 
14
 
15
  # Setup Qdrant Client
16
  qdrant_client = QdrantClient(
 
31
  collection_name = "product_bge-m3"
32
  threshold = 0.45
33
 
 
 
34
  # Utils
35
  def is_non_thai(text):
36
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
37
+
38
  def normalize(text: str) -> str:
39
  if is_non_thai(text):
40
  return text.strip()
 
67
  return
68
 
69
  if len(result) > 0:
70
+ topk = 50 # ดึงมา rerank แค่ 50 อันดับแรกจาก Qdrant
71
+ result = result[:topk]
72
+
73
+ scored = []
74
+ for r in result:
75
+ name = r.payload.get("name", "")
76
+
77
+ # ถ้า query สั้นเกินไป ให้ fuzzy_score = 0 เพื่อกันเพี้ยน
78
+ if len(corrected_query) >= 3 and name:
79
+ fuzzy_score = fuzz.partial_ratio(corrected_query, name) / 100.0
80
+ else:
81
+ fuzzy_score = 0.0
82
+
83
+ # รวม hybrid score
84
+ hybrid_score = 0.6 * r.score + 0.4 * fuzzy_score
85
+ scored.append((r, hybrid_score))
86
+
87
+ # เรียงตาม hybrid score แล้วกรองผลลัพธ์ที่ hybrid score ต่ำเกิน
88
+ scored = sorted(scored, key=lambda x: x[1], reverse=True)
89
+ result = [r[0] for r in scored]
90
 
91
  elapsed = time.time() - start_time
92
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"