Spaces:
Running
Running
import html
import os
import re
import time
import unicodedata
from datetime import datetime

import gradio as gr
from huggingface_hub import hf_hub_download
from pyairtable import Api
from pyairtable import Table
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue
from rapidfuzz import fuzz
from visual_bge.modeling import Visualized_BGE
# --- Qdrant vector-database client ---
# URL and API key come from environment variables. os.environ.get returns None
# for an unset variable and nothing here validates that.
# NOTE(review): what QdrantClient does with url=None should be confirmed.
# NOTE(review): timeout=30.0 is presumably seconds; confirm against the qdrant_client docs.
qdrant_client = QdrantClient(
    url=os.environ.get("Qdrant_url"),
    api_key=os.environ.get("Qdrant_api"),
    timeout=30.0
)
# --- Airtable: destination of the records written by log_feedback() ---
AIRTABLE_API_KEY = os.environ.get("airtable_api")
BASE_ID = os.environ.get("airtable_baseid")
TABLE_NAME = "Feedback_search"  # Airtable table name
api = Api(AIRTABLE_API_KEY)  # Airtable API client
table = api.table(BASE_ID, TABLE_NAME)  # handle to the feedback table; log_feedback() calls table.create()
# --- Embedding model, loaded once when this module is executed ---
# Fetch the Visualized BGE weight file from the Hugging Face Hub;
# model_weight is the local path of the downloaded file.
model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")
# Visualized BGE built on the BGE-M3 text model; search_product() uses it to embed queries.
model = Visualized_BGE(
    model_name_bge="BAAI/bge-m3",
    model_weight=model_weight
)
collection_name = "product_visual_bge"  # Qdrant collection queried by search_product()
threshold = 0.5  # minimum reranked (hybrid) score for a product to be displayed
# Utils | |
def is_non_thai(text):
    """Return True when *text* is non-empty and made only of ASCII letters,
    digits, '&', '-' and whitespace.

    Despite the name, any other character (punctuation such as '!', accented
    Latin letters, Thai script, ...) makes this return False.
    """
    allowed_only = re.fullmatch(r'[A-Za-z0-9&\-\s]+', text)
    return bool(allowed_only)
def normalize(text: str) -> str:
    """Clean up a raw search query.

    Text accepted by ``is_non_thai`` (ASCII letters, digits, '&', '-',
    whitespace) is only stripped; its case is kept. Any other text is
    NFC-normalized, has the typos "เแ" and "เเ" (Sara E pressed before Sara Ae,
    or Sara E pressed twice) replaced by "แ", and is then stripped and
    lowercased.
    """
    if not is_non_thai(text):
        cleaned = unicodedata.normalize("NFC", text)
        # Fix repeated-เ typing mistakes; the order of the two replacements
        # is significant for overlapping runs, so keep it fixed.
        for typo, fix in (("เแ", "แ"), ("เเ", "แ")):
            cleaned = cleaned.replace(typo, fix)
        return cleaned.strip().lower()
    return text.strip()
# Global state: details of the most recent search, written by search_product()
# and sent to Airtable by log_feedback().
# NOTE(review): this is a single module-level dict, so presumably it is shared
# by every concurrent user of the app; confirm, and consider per-session gr.State.
latest_query_result = dict(query="", result="", raw_query="", time="")
# Search Function | |
def _clear_latest_result():
    """Reset latest_query_result to its initial empty values."""
    latest_query_result.update({"query": "", "result": "", "raw_query": "", "time": ""})


def _rerank(points, corrected_query):
    """Re-score Qdrant hits with a hybrid semantic + fuzzy score, best first.

    For each point the hybrid score starts as Qdrant's similarity score. When
    the fuzzy partial ratio between the query and the product name is at least
    0.5, it becomes ``0.7 * semantic + 0.3 * fuzzy_name``. A brand partial ratio
    of at least 0.8 then multiplies it by 1.2. The hybrid and component scores
    are written into each point's payload ("score", "fuzzy_name_score",
    "fuzzy_brand_score", "semantic_score") for display and debugging.

    Args:
        points: Qdrant hits, each with a ``.score`` and a dict ``.payload``.
        corrected_query: the normalized query text.

    Returns:
        A new list of the points sorted by hybrid score, highest first.
    """
    query_lower = corrected_query.lower()
    # Very short queries give spurious partial-ratio hits against names, so
    # name matching is only used for queries of 3+ characters.
    use_name = len(corrected_query) >= 3
    scored = []
    for point in points:
        payload = point.payload
        # `or ""` so a stored None is not matched as the literal text "none".
        name = str(payload.get("name") or "").lower()
        brand = str(payload.get("brand") or "").lower()
        if use_name and name:
            fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0
        else:
            fuzzy_name_score = 0.0
        fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0
        if fuzzy_name_score < 0.5:
            hybrid_score = point.score  # weak name match: keep the semantic score
        else:
            hybrid_score = 0.7 * point.score + 0.3 * fuzzy_name_score
        if fuzzy_brand_score >= 0.8:
            hybrid_score *= 1.2  # confident brand match: boost by 20%
        payload["score"] = hybrid_score  # compared against `threshold` when rendering
        payload["fuzzy_name_score"] = fuzzy_name_score  # kept for debugging
        payload["fuzzy_brand_score"] = fuzzy_brand_score  # kept for debugging
        payload["semantic_score"] = point.score  # kept for debugging
        scored.append((point, hybrid_score))
    scored.sort(key=lambda item: item[1], reverse=True)
    return [point for point, _ in scored]


def _render_card(name, brand, price, img_url, score):
    """Return the HTML card for one product; every value is HTML-escaped."""
    return f"""
    <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
        <img src="{html.escape(str(img_url))}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
        <div style="margin-top: 10px;">
            <div style="font-weight: bold; font-size: 14px;">{html.escape(str(name))}</div>
            <div style="color: gray; font-size: 12px;">{html.escape(str(brand))}</div>
            <div style="color: green; margin: 4px 0;">฿{html.escape(str(price))}</div>
            <div style="font-size: 12px; color: #555;">score: {html.escape(score)}</div>
        </div>
    </div>
    """


def search_product(query):
    """Search the product collection for *query*, streaming UI updates.

    This generator is a Gradio event handler yielding ``(status_update, html)``
    pairs. It first yields a "searching" status with empty results. It then
    yields one of: a blank-query notice, a Qdrant error, or the final status
    plus an HTML grid of products whose hybrid score is at least ``threshold``.

    Pipeline: normalize the query, embed it with ``model``, fetch the 50
    nearest points from Qdrant, rerank them with ``_rerank``, render survivors.

    Side effect: on success ``latest_query_result`` is overwritten in one step
    with the raw query, normalized query, a " | "-separated result summary and
    the elapsed seconds, for ``log_feedback``. It is cleared when the query is
    blank or Qdrant fails, so feedback is never logged against a mixture of
    this query and an earlier search's results.
    """
    yield gr.update(value="🔄 กำลังค้นหา..."), ""
    start_time = time.time()
    query = query or ""  # defensive: treat a missing value as an empty query
    corrected_query = normalize(query)
    if not corrected_query:
        # Nothing to search for; don't rank products against an empty embedding.
        _clear_latest_result()
        yield gr.update(value="⚠️ กรุณาพิมพ์คำค้นหา"), ""
        return

    query_embed = model.encode(text=corrected_query)[0]
    try:
        points = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,
            limit=50,  # candidates to rerank
        ).points
    except Exception as e:
        _clear_latest_result()
        yield gr.update(value="❌ Qdrant error"), f"<p>❌ Qdrant error: {html.escape(str(e))}</p>"
        return

    ranked = _rerank(points, corrected_query)
    elapsed = time.time() - start_time

    parts = [f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"]
    if corrected_query != query:
        # The query is user input: escape it before putting it into HTML.
        parts.append(
            f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code>"
            f" → <code>{html.escape(corrected_query)}</code></p>"
        )
    parts.append('<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">')

    summary = []
    for point in ranked:
        payload = point.payload
        if payload["score"] < threshold:
            continue
        name = payload.get("name") or "ไม่ทราบชื่อสินค้า"
        price = payload.get("price")
        if price is None:
            price = "ไม่ระบุ"
        score = f"{payload['score']:.4f}"
        parts.append(_render_card(
            name,
            payload.get("brand") or "",
            price,
            payload.get("image_url") or "",
            score,
        ))
        summary.append(f"{name} (score: {score})")
    parts.append("</div>")
    if not summary:
        parts.append('<div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>')

    latest_query_result.update({
        "query": corrected_query,
        "result": " | ".join(summary),
        "raw_query": query,
        "time": elapsed,
    })
    yield gr.update(value="✅ ค้นหาเสร็จแล้ว!"), "".join(parts)
# Feedback Function | |
def log_feedback(feedback):
    """Save the user's verdict on the most recent search as an Airtable record.

    Args:
        feedback: value stored in the "feedback" column; the UI buttons pass
            "match" or "not_match".

    Returns:
        A status message for the feedback textbox. Airtable failures are
        reported in the message instead of raised, so the UI never errors.
    """
    # "query" is empty until a search completes (initial state, or a cleared
    # state). Without this check the record would hold "" placeholders,
    # including a "" elapsed time.
    if not latest_query_result["query"]:
        return "⚠️ No search result to give feedback on yet."
    # Date only (no time of day).
    # NOTE(review): presumably this matches an Airtable date field; confirm.
    now = datetime.now().strftime("%Y-%m-%d")
    record = {  # keys must match the Airtable column names exactly
        "model": "BGE M3",
        "timestamp": now,
        "raw_query": latest_query_result["raw_query"],
        "query": latest_query_result["query"],
        "result": latest_query_result["result"],
        "time(second)": latest_query_result["time"],
        "feedback": feedback
    }
    try:
        table.create(record)
    except Exception as e:
        return f"❌ Failed to save feedback: {str(e)}"
    return "✅ Feedback saved to Airtable!"
# Gradio UI: a query box that streams search results, plus two buttons that
# record whether the latest results matched the query.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (BGE M3 + Qdrant)")
    query_input = gr.Textbox(label="พิมพ์คำค้นหา")
    result_output = gr.HTML(label="📋 ผลลัพธ์")
    status_output = gr.Textbox(label="🕒 สถานะ", interactive=False)
    with gr.Row():
        match_btn = gr.Button("✅ ตรง")
        not_match_btn = gr.Button("❌ ไม่ตรง")
    # Read-only like status_output: it only displays log_feedback's result.
    feedback_status = gr.Textbox(label="📬 สถานะ Feedback", interactive=False)
    # search_product is a generator: it yields a "searching" status first,
    # then the final status and the results HTML.
    query_input.submit(
        search_product,
        inputs=[query_input],
        outputs=[status_output, result_output]
    )
    match_btn.click(fn=lambda: log_feedback("match"), outputs=feedback_status)
    not_match_btn.click(fn=lambda: log_feedback("not_match"), outputs=feedback_status)
# Generator (streaming) handlers require the queue in Gradio 3.x. Gradio 4+
# enables it by default, so this call is harmless there.
demo.queue()
# share=True only creates a public link when running outside Hugging Face Spaces.
demo.launch(share=True)