Spaces:
Running
Running
File size: 8,627 Bytes
28b8e02 80c9031 28b8e02 7a2742e 7c23eb0 9ddaa27 39b722c 5629bb7 6d417ec 28b8e02 41cf03d c05c4ca 28b8e02 6d417ec 41cf03d 80c9031 7c23eb0 08defce 80c9031 08defce 6d417ec 80c9031 374e7c8 7a2742e 9ddaa27 80c9031 374e7c8 80c9031 9ddaa27 6d417ec 80c9031 50c341d 08defce 68b12c7 80c9031 79da1d1 80c9031 08defce 80c9031 5629bb7 80c9031 49c543f 80c9031 49c543f 5629bb7 80c9031 79da1d1 28b8e02 68b12c7 7a2742e 80c9031 7a2742e 80c9031 4ccade9 7a2742e 80c9031 7a2742e 6133ede 7a2742e 6133ede 80c9031 4ccade9 80c9031 6133ede 80c9031 4ccade9 6133ede 4ccade9 80c9031 7a2742e 80c9031 dbd7784 80c9031 8c1aede 49c543f 66a3591 6d417ec 08defce 28b8e02 80c9031 6d417ec 4ccade9 80c9031 6d417ec 8c1aede 6d417ec 66a3591 8c1aede 28b8e02 c68ca70 6d417ec 28b8e02 6d417ec 28b8e02 79da1d1 28b8e02 08defce 68b12c7 7c23eb0 66a3591 80c9031 7c23eb0 68b12c7 7c23eb0 cddab55 7c23eb0 ef6809f 7c23eb0 28b8e02 6d417ec 28b8e02 68b12c7 08defce 68b12c7 6d417ec 79da1d1 28b8e02 79da1d1 68b12c7 28b8e02 6d417ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import html
import os
import re
import time
import unicodedata
from datetime import datetime

import gradio as gr
from huggingface_hub import hf_hub_download
from pyairtable import Api
from pyairtable import Table
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from rapidfuzz import fuzz
from visual_bge.modeling import Visualized_BGE
# Qdrant client; endpoint and API key are read from the environment.
qdrant_client = QdrantClient(
    url=os.getenv("Qdrant_url"),
    api_key=os.getenv("Qdrant_api"),
    timeout=30.0,
)

# Airtable configuration; feedback rows are written to this table.
AIRTABLE_API_KEY = os.getenv("airtable_api")
BASE_ID = os.getenv("airtable_baseid")
TABLE_NAME = "Feedback_search"  # table is addressed by name
api = Api(AIRTABLE_API_KEY)
table = api.table(BASE_ID, TABLE_NAME)

# Download the Visualized-BGE weights once at import time, then load the model.
model_weight = hf_hub_download(
    repo_id="BAAI/bge-visualized",
    filename="Visualized_m3.pth",
)
model = Visualized_BGE(
    model_name_bge="BAAI/bge-m3",
    model_weight=model_weight,
)

collection_name = "product_visual_bge"  # Qdrant collection that is queried
threshold = 0.5  # minimum hybrid score a product needs to be displayed
# Utils
def is_non_thai(text):
    """Return True when *text* contains only ASCII letters, digits, '&', '-' and whitespace.

    An empty string, or a string containing any other character (Thai
    script, punctuation, ...), yields False.
    """
    return bool(re.match(r'^[A-Za-z0-9&\-\s]+$', text))
def normalize(text: str) -> str:
    """Clean up a search query before it is embedded.

    Queries accepted by ``is_non_thai`` are only stripped of surrounding
    whitespace. Every other query is NFC-normalized, has a doubled leading
    vowel repaired, and is then stripped and lowercased.
    """
    if not is_non_thai(text):
        composed = unicodedata.normalize("NFC", text)
        # Repair the common typo of pressing the 'sara e' key more than once.
        repaired = composed.replace("เแ", "แ").replace("เเ", "แ")
        return repaired.strip().lower()
    return text.strip()
# Module-level record of the most recent search; it is filled in by
# search_product and read by log_feedback when a row is sent to Airtable.
latest_query_result = {
    "query": "",      # normalized query text
    "result": "",     # summary of the products that were displayed
    "raw_query": "",  # query exactly as the user typed it
    "time": "",       # elapsed search time
}
# Search Function
# Number of nearest neighbours requested from Qdrant and then re-ranked.
_RERANK_TOP_K = 50


def _payload_text(payload, key):
    """Return ``payload[key]`` as lowercase text, or "" when missing or None."""
    value = payload.get(key)
    return "" if value is None else str(value).lower()


def _rerank_hit(hit, query_lower, allow_name_match):
    """Blend one hit's semantic score with fuzzy name/brand matches.

    The individual scores and the resulting hybrid score are stored in
    ``hit.payload``; the hybrid score is also returned.
    """
    name = _payload_text(hit.payload, "name")
    brand = _payload_text(hit.payload, "brand")

    fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0
    # Name matching is skipped for very short queries (and for products
    # without a name) so that it cannot distort the ranking.
    if allow_name_match and name:
        fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0
    else:
        fuzzy_name_score = 0.0

    if fuzzy_name_score < 0.5:
        # Weak name match: keep the Qdrant score untouched.
        hybrid_score = hit.score
    else:
        # 70% semantic score + 30% fuzzy name score.
        hybrid_score = 0.7 * hit.score + 0.3 * fuzzy_name_score
    if fuzzy_brand_score >= 0.8:
        # Confident brand match: boost the score by 20%.
        hybrid_score *= 1.2

    hit.payload["score"] = hybrid_score  # compared against `threshold` when rendering
    hit.payload["fuzzy_name_score"] = fuzzy_name_score  # kept for debugging
    hit.payload["fuzzy_brand_score"] = fuzzy_brand_score  # kept for debugging
    hit.payload["semantic_score"] = hit.score  # kept for debugging
    return hybrid_score


def _render_product_card(payload, score_text):
    """Return the HTML card for one product, escaping every payload value."""
    name = html.escape(str(payload.get("name", "ไม่ทราบชื่อสินค้า")))
    brand = html.escape(str(payload.get("brand", "")))
    price = html.escape(str(payload.get("price", "ไม่ระบุ")))
    img_url = html.escape(str(payload.get("image_url", "")), quote=True)
    return f"""
    <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
    <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
    <div style="margin-top: 10px;">
    <div style="font-weight: bold; font-size: 14px;">{name}</div>
    <div style="color: gray; font-size: 12px;">{brand}</div>
    <div style="color: green; margin: 4px 0;">฿{price}</div>
    <div style="font-size: 12px; color: #555;">score: {score_text}</div>
    </div>
    </div>
    """


def search_product(query):
    """Search products for *query* and stream UI updates.

    This is a generator: it first yields a "searching" status with empty
    HTML, then yields either an error status (when the Qdrant query raises)
    or the final status together with the rendered result grid.

    The query is normalized, embedded with the module-level model, sent to
    Qdrant, and the returned hits are re-ranked with fuzzy name/brand
    matching. Only hits whose hybrid score reaches ``threshold`` are shown.
    ``latest_query_result`` is updated so that feedback can be logged later.

    Yields:
        Tuples of ``(status update, html string)``.
    """
    yield gr.update(value="🔄 กำลังค้นหา..."), ""  # shown while the search runs
    start_time = time.time()

    # Reset the whole record up front so that a failed search can never be
    # logged together with the previous search's query/result/time.
    latest_query_result.update(
        {"raw_query": query, "query": "", "result": "", "time": ""}
    )

    corrected_query = normalize(query)
    query_embed = model.encode(text=corrected_query)[0]
    try:
        hits = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,
            limit=_RERANK_TOP_K,
        ).points
    except Exception as e:
        # The exception text is escaped because it is rendered as HTML.
        yield gr.update(value="❌ Qdrant error"), f"<p>❌ Qdrant error: {html.escape(str(e))}</p>"
        return

    # Re-rank: annotate each hit with its hybrid score, then sort descending.
    query_lower = corrected_query.lower()
    allow_name_match = len(corrected_query) >= 3
    for hit in hits:
        _rerank_hit(hit, query_lower, allow_name_match)
    hits = sorted(hits, key=lambda hit: hit.payload["score"], reverse=True)

    elapsed = time.time() - start_time
    html_parts = [f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"]
    if corrected_query != query:
        # Both strings come from the user, so they are escaped before display.
        html_parts.append(
            f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code> → "
            f"<code>{html.escape(corrected_query)}</code></p>"
        )
    html_parts.append(
        '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
    )

    summary_parts = []
    for hit in hits:
        if hit.payload["score"] >= threshold:  # show only sufficiently relevant products
            score_text = f"{hit.payload['score']:.4f}"
            html_parts.append(_render_product_card(hit.payload, score_text))
            name = hit.payload.get("name", "ไม่ทราบชื่อสินค้า")
            summary_parts.append(f"{name} (score: {score_text}) | ")
    html_parts.append("</div>")
    if not summary_parts:
        html_parts.append(
            '<div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>'
        )

    latest_query_result.update({
        "query": corrected_query,
        "result": "".join(summary_parts).strip(),
        "time": elapsed,
    })
    yield gr.update(value="✅ ค้นหาเสร็จแล้ว!"), "".join(html_parts)
# Feedback Function
def log_feedback(feedback):
    """Store *feedback* for the most recent search as a new Airtable row.

    The row is assembled from the module-level ``latest_query_result``
    record plus today's date. The field names must match the column names
    of the Airtable table.

    Returns:
        A status message: a success text, or a failure text that includes
        the exception message when anything in the process raises.
    """
    try:
        record = {
            "model": "BGE M3",
            "timestamp": datetime.now().strftime("%Y-%m-%d"),
            "raw_query": latest_query_result["raw_query"],
            "query": latest_query_result["query"],
            "result": latest_query_result["result"],
            "time(second)": latest_query_result["time"],
            "feedback": feedback,
        }
        table.create(record)
    except Exception as err:
        return f"❌ Failed to save feedback: {err!s}"
    return "✅ Feedback saved to Airtable!"
# Gradio UI: a query box, the result grid, a status line and two feedback buttons.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (BGE M3 + Qdrant)")

    query_input = gr.Textbox(label="พิมพ์คำค้นหา")
    result_output = gr.HTML(label="📋 ผลลัพธ์")
    status_output = gr.Textbox(label="🕒 สถานะ", interactive=False)

    # Feedback buttons sit side by side.
    with gr.Row():
        match_btn = gr.Button(value="✅ ตรง")
        not_match_btn = gr.Button(value="❌ ไม่ตรง")
    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")

    # Submitting the query box runs the search; search_product yields
    # (status, html) pairs that fill the two outputs in this order.
    query_input.submit(
        fn=search_product,
        inputs=[query_input],
        outputs=[status_output, result_output],
    )

    # Each button logs a fixed verdict for the most recent search.
    match_btn.click(
        fn=lambda: log_feedback("match"),
        outputs=feedback_status,
    )
    not_match_btn.click(
        fn=lambda: log_feedback("not_match"),
        outputs=feedback_status,
    )

demo.launch(share=True)
|