File size: 8,627 Bytes
28b8e02
 
 
80c9031
 
28b8e02
 
 
7a2742e
7c23eb0
 
9ddaa27
39b722c
5629bb7
6d417ec
28b8e02
41cf03d
 
c05c4ca
28b8e02
 
6d417ec
41cf03d
 
80c9031
 
 
7c23eb0
08defce
80c9031
 
 
 
 
 
 
 
08defce
6d417ec
80c9031
374e7c8
7a2742e
9ddaa27
80c9031
374e7c8
80c9031
 
9ddaa27
6d417ec
80c9031
50c341d
08defce
68b12c7
80c9031
79da1d1
80c9031
 
08defce
80c9031
 
5629bb7
 
80c9031
49c543f
80c9031
 
 
 
49c543f
5629bb7
80c9031
79da1d1
28b8e02
68b12c7
7a2742e
 
 
80c9031
7a2742e
80c9031
 
 
4ccade9
7a2742e
 
80c9031
 
7a2742e
6133ede
 
 
7a2742e
6133ede
80c9031
4ccade9
80c9031
6133ede
80c9031
4ccade9
6133ede
 
4ccade9
80c9031
7a2742e
 
80c9031
 
dbd7784
80c9031
8c1aede
49c543f
66a3591
6d417ec
 
08defce
28b8e02
80c9031
 
6d417ec
4ccade9
80c9031
6d417ec
 
 
 
 
 
 
 
 
 
 
 
8c1aede
6d417ec
 
66a3591
8c1aede
28b8e02
c68ca70
6d417ec
28b8e02
6d417ec
 
 
 
 
28b8e02
79da1d1
28b8e02
08defce
68b12c7
7c23eb0
66a3591
80c9031
 
7c23eb0
68b12c7
7c23eb0
cddab55
7c23eb0
 
ef6809f
7c23eb0
 
 
 
 
28b8e02
6d417ec
28b8e02
68b12c7
08defce
68b12c7
6d417ec
79da1d1
28b8e02
 
 
 
 
 
 
79da1d1
 
 
 
 
68b12c7
 
28b8e02
6d417ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import html
import os
import re
import time
import unicodedata
from datetime import datetime

import gradio as gr
from huggingface_hub import hf_hub_download
from pyairtable import Api
from pyairtable import Table
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from rapidfuzz import fuzz
from visual_bge.modeling import Visualized_BGE

# Setup Qdrant Client
# The endpoint URL and API key are read from the "Qdrant_url" and "Qdrant_api"
# environment variables; os.environ.get yields None when a variable is unset.
qdrant_client = QdrantClient(
    url=os.environ.get("Qdrant_url"),
    api_key=os.environ.get("Qdrant_api"),
    timeout=30.0
)

# Airtable Config
# Credentials come from the "airtable_api" and "airtable_baseid" environment
# variables (None when unset).
AIRTABLE_API_KEY = os.environ.get("airtable_api")
BASE_ID = os.environ.get("airtable_baseid")
TABLE_NAME = "Feedback_search" # name of the Airtable table that log_feedback writes to
api = Api(AIRTABLE_API_KEY) # Airtable API client built from the key above
table = api.table(BASE_ID, TABLE_NAME) # handle to the feedback table inside that base

# Preload Models
# Fetch the "Visualized_m3.pth" weight file from the "BAAI/bge-visualized" Hub repo.
model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")
# Build the model once at import time; search_product reuses this instance
# for every query it embeds.
model = Visualized_BGE(
    model_name_bge="BAAI/bge-m3",
    model_weight=model_weight
)
collection_name = "product_visual_bge" # Qdrant collection that search_product queries
threshold = 0.5 # minimum hybrid score a product needs to be shown in the results

# Utils
def is_non_thai(text):
    """Report whether ``text`` uses only ASCII letters, digits, ``&``, ``-`` and whitespace.

    Returns:
        True when ``text`` is non-empty and every character belongs to that
        set; False otherwise (Thai text, other punctuation, or an empty
        string).
    """
    matched = re.match(r'^[A-Za-z0-9&\-\s]+$', text)
    # A Match object is always truthy, so bool() maps "matched" / "no match"
    # onto True / False.
    return bool(matched)

def normalize(text: str) -> str:
    """Clean up a search query before it is embedded.

    Input that ``is_non_thai`` accepts is only stripped of surrounding
    whitespace (its letter case is kept). Any other input is NFC-normalized,
    has two common mistyped vowel sequences replaced, and is then stripped
    and lower-cased.

    Args:
        text: The raw query string.

    Returns:
        The cleaned query string.
    """
    if not is_non_thai(text):
        composed = unicodedata.normalize("NFC", text)
        # Users often type sara e twice (or sara e followed by sara ae) where
        # a single sara ae was intended; collapse both forms to sara ae.
        for typo, fix in (("เแ", "แ"), ("เเ", "แ")):
            composed = composed.replace(typo, fix)
        return composed.strip().lower()
    return text.strip()

# Global state
# Record of the most recent search. search_product fills it in and
# log_feedback reads it when a feedback row is sent to Airtable.
# Every field starts out as an empty string.
latest_query_result = dict.fromkeys(("query", "result", "raw_query", "time"), "")

# Search Function
def _score_point(point, query_lower, use_name_fuzzy):
    """Compute the hybrid score of one Qdrant hit and store it in its payload.

    The hybrid score starts from the vector similarity (``point.score``).
    When the fuzzy name match is at least 0.5 the two are blended 70/30, and
    when the fuzzy brand match is at least 0.8 the result is multiplied by 1.2.

    Args:
        point: A scored point returned by Qdrant (must expose ``payload`` and
            ``score``).
        query_lower: The normalized query, lower-cased.
        use_name_fuzzy: False for very short queries, whose name match is
            forced to 0.0 because they partially match almost any name.

    Returns:
        The hybrid score, which is also written to ``payload["score"]``.
    """
    payload = point.payload
    name = str(payload.get("name", "")).lower()
    brand = str(payload.get("brand", "")).lower()

    if use_name_fuzzy and name:
        fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0
    else:
        fuzzy_name_score = 0.0
    # The brand match is computed for every query length.
    fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0

    if fuzzy_name_score < 0.5:
        hybrid_score = point.score  # weak name match: keep the Qdrant score as is
    else:
        hybrid_score = 0.7 * point.score + 0.3 * fuzzy_name_score
    if fuzzy_brand_score >= 0.8:
        hybrid_score *= 1.2  # confident brand match: boost the score by 20%

    payload["score"] = hybrid_score  # compared against the threshold at display time
    payload["fuzzy_name_score"] = fuzzy_name_score  # kept for debugging
    payload["fuzzy_brand_score"] = fuzzy_brand_score  # kept for debugging
    payload["semantic_score"] = point.score  # kept for debugging
    return hybrid_score


def _render_card(payload):
    """Build the HTML card and the plain-text summary entry for one product.

    Payload values come from the database, so every value placed in the
    markup is HTML-escaped. The summary entry keeps the raw product name
    because it is stored as plain text, not rendered.

    Returns:
        A ``(card_html, summary_entry)`` tuple.
    """
    raw_name = payload.get("name", "ไม่ทราบชื่อสินค้า")
    score = f"{payload['score']:.4f}"
    name = html.escape(str(raw_name))
    img_url = html.escape(str(payload.get("image_url", "")), quote=True)
    price = html.escape(str(payload.get("price", "ไม่ระบุ")))
    brand = html.escape(str(payload.get("brand", "")))

    card = f"""
            <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
                <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
                <div style="margin-top: 10px;">
                    <div style="font-weight: bold; font-size: 14px;">{name}</div>
                    <div style="color: gray; font-size: 12px;">{brand}</div>
                    <div style="color: green; margin: 4px 0;">฿{price}</div>
                    <div style="font-size: 12px; color: #555;">score: {score}</div>
                </div>
            </div>
            """
    return card, f"{raw_name} (score: {score}) | "


def search_product(query):
    """Search products for ``query`` and stream status/HTML updates to the UI.

    The query is normalized, embedded with the preloaded model and sent to
    Qdrant. The hits are re-ranked by a hybrid of vector similarity and fuzzy
    name/brand matching, and those scoring at least ``threshold`` are rendered
    as a grid of product cards. The module-level ``latest_query_result`` is
    updated so that ``log_feedback`` can report on this search.

    Args:
        query: The raw text typed by the user. ``None`` is treated as empty.

    Yields:
        ``(status_update, html)`` pairs: first a "searching" status with empty
        HTML, then either the final results, a Qdrant error message, or a
        prompt to type a query when the input is empty.
    """
    yield gr.update(value="🔄 กำลังค้นหา..."), ""  # shown while the search runs

    start_time = time.perf_counter()
    query = query or ""

    # Reset the shared record first so feedback given after a failed or empty
    # search is never paired with the previous search's results.
    latest_query_result.update({"raw_query": query, "query": "", "result": "", "time": ""})

    corrected_query = normalize(query)
    latest_query_result["query"] = corrected_query
    if not corrected_query:
        # Nothing to embed; ask for input instead of returning unrelated products.
        yield gr.update(value="⚠️ กรุณาพิมพ์คำค้นหา"), ""
        return

    query_embed = model.encode(text=corrected_query)[0]  # query text -> vector

    try:
        result = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,  # payload carries name/brand/price/image_url
            limit=50,  # only the top 50 hits are re-ranked
        ).points
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        # The message is escaped because it is rendered as HTML.
        yield gr.update(value="❌ Qdrant error"), f"<p>❌ Qdrant error: {html.escape(str(e))}</p>"
        return

    # Re-rank by hybrid score, best first (sorted() is stable, so ties keep
    # Qdrant's original order).
    query_lower = corrected_query.lower()
    use_name_fuzzy = len(corrected_query) >= 3
    for point in result:
        _score_point(point, query_lower, use_name_fuzzy)
    result = sorted(result, key=lambda p: p.payload["score"], reverse=True)

    elapsed = time.perf_counter() - start_time  # search time, excluding rendering

    parts = [f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"]
    if corrected_query != query:
        # The query is user input, so it is escaped before being echoed back.
        parts.append(
            f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code> → <code>{html.escape(corrected_query)}</code></p>"
        )
    parts.append('<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">')

    summary = []
    for point in result:
        if point.payload["score"] >= threshold:  # show only sufficiently relevant products
            card, entry = _render_card(point.payload)
            parts.append(card)
            summary.append(entry)

    parts.append("</div>")

    if not summary:
        parts.append('<div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>')

    latest_query_result.update({
        "query": corrected_query,
        "result": "".join(summary).strip(),
        "time": elapsed,
    })

    yield gr.update(value="✅ ค้นหาเสร็จแล้ว!"), "".join(parts)

# Feedback Function
def log_feedback(feedback):
    """Store the user's verdict on the latest search as a new Airtable row.

    Args:
        feedback: Label to record for the most recent search (the UI passes
            "match" or "not_match").

    Returns:
        A status message saying whether the row was saved. Any exception is
        caught and reported in the message rather than raised.
    """
    try:
        last_search = latest_query_result
        # Keys must match the column names of the Airtable table.
        record = {
            "model": "BGE M3",
            "timestamp": datetime.now().strftime("%Y-%m-%d"),
            "raw_query": last_search["raw_query"],
            "query": last_search["query"],
            "result": last_search["result"],
            "time(second)": last_search["time"],
            "feedback": feedback,
        }
        table.create(record)
    except Exception as err:
        return f"❌ Failed to save feedback: {err!s}"
    else:
        return "✅ Feedback saved to Airtable!"

# Gradio UI
# Layout, top to bottom: title, query box, results area, search status,
# a row with the two feedback buttons, and the feedback status box.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (BGE M3 + Qdrant)")

    query_input = gr.Textbox(label="พิมพ์คำค้นหา")  # search query input
    result_output = gr.HTML(label="📋 ผลลัพธ์")  # receives the HTML produced by search_product
    status_output = gr.Textbox(label="🕒 สถานะ", interactive=False)  # read-only search status

    with gr.Row():
        match_btn = gr.Button("✅ ตรง")  # logs feedback "match"
        not_match_btn = gr.Button("❌ ไม่ตรง")  # logs feedback "not_match"

    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")  # shows the message returned by log_feedback

    # Submitting the query box runs search_product. It is a generator, so each
    # (status, html) pair it yields is pushed to the two outputs in that order.
    query_input.submit(
        search_product,
        inputs=[query_input],
        outputs=[status_output, result_output]
    )
    # The buttons take no inputs: log_feedback reads the module-level
    # latest_query_result, which every search overwrites.
    # NOTE(review): with several simultaneous users this shared record could
    # attach feedback to another user's search — confirm whether per-session
    # state is needed.
    match_btn.click(fn=lambda: log_feedback("match"), outputs=feedback_status)
    not_match_btn.click(fn=lambda: log_feedback("not_match"), outputs=feedback_status)

# Start the app when the module is executed; share=True asks Gradio for a
# public share link.
demo.launch(share=True)