# File: environment / app.py
# Last commit: "change to model bge visual" by sibthinon (80c9031, verified)
import html
import os
import re
import time
import unicodedata
from datetime import datetime

import gradio as gr
from huggingface_hub import hf_hub_download
from pyairtable import Api
from pyairtable import Table
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from rapidfuzz import fuzz
from visual_bge.modeling import Visualized_BGE
# Qdrant connection. The endpoint URL and API key are read from the
# "Qdrant_url" / "Qdrant_api" environment variables; requests use a
# 30-second timeout.
qdrant_client = QdrantClient(
    url=os.getenv("Qdrant_url"),
    api_key=os.getenv("Qdrant_api"),
    timeout=30.0,
)
# Airtable connection used for feedback logging. Credentials come from the
# "airtable_api" / "airtable_baseid" environment variables.
AIRTABLE_API_KEY = os.getenv("airtable_api")
BASE_ID = os.getenv("airtable_baseid")
TABLE_NAME = "Feedback_search"  # Airtable table that receives feedback rows

api = Api(AIRTABLE_API_KEY)  # authenticated Airtable API handle
table = api.table(BASE_ID, TABLE_NAME)  # table object used by log_feedback
# Retrieval settings
collection_name = "product_visual_bge"  # Qdrant collection that is queried
threshold = 0.5  # minimum hybrid score a product needs to be displayed

# Fetch the Visualized-BGE checkpoint from the Hugging Face Hub, then build
# the encoder on top of the BAAI/bge-m3 text backbone. Both steps run once at
# import time.
model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")
model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight=model_weight)
# Utils
def is_non_thai(text):
    """Tell whether ``text`` contains no Thai (or other non-ASCII) letters.

    Returns True only when the whole string is made of ASCII letters,
    digits, ``&``, ``-`` and whitespace. An empty string returns False.
    """
    return bool(re.match(r'^[A-Za-z0-9&\-\s]+$', text))
def normalize(text: str) -> str:
    """Clean up a search query before it is embedded.

    Queries without Thai characters are only stripped of surrounding
    whitespace (their case is preserved). Any other query is NFC-normalized,
    has the common "sara e typed twice" slip folded into a single sara ae,
    and is then stripped and lower-cased.
    """
    if not is_non_thai(text):
        composed = unicodedata.normalize("NFC", text)
        # Apply the replacements in this order: sara e + sara ae first,
        # then a doubled sara e.
        for mistyped in ("เแ", "เเ"):
            composed = composed.replace(mistyped, "แ")
        return composed.strip().lower()
    return text.strip()
# Global state: snapshot of the most recent search, read back when feedback
# is sent to Airtable. All four fields start out as empty strings.
# NOTE(review): this dict is module-level, so it is shared by every visitor
# of the app; concurrent users can overwrite each other's snapshot.
latest_query_result = dict.fromkeys(("query", "result", "raw_query", "time"), "")
# Search Function
def _rerank_points(points, corrected_query, top_k=50):
    """Re-score Qdrant hits with fuzzy name/brand matching, best first.

    Each point's payload gains four keys: "score" (the hybrid score later
    compared against the display threshold) and, for debugging,
    "fuzzy_name_score", "fuzzy_brand_score" and "semantic_score".

    Args:
        points: Scored points returned by Qdrant (payload included).
        corrected_query: The normalized query text.
        top_k: Only the first ``top_k`` points are re-ranked.

    Returns:
        The re-ranked points, sorted by hybrid score in descending order.
    """
    query_lower = corrected_query.lower()  # loop-invariant, computed once
    # Very short queries partially match almost any product name, so the
    # name-based fuzzy score is disabled for them to avoid skewed rankings.
    use_name_score = len(corrected_query) >= 3

    scored = []
    for point in points[:top_k]:
        payload = point.payload
        # `or ""` keeps a stored null from becoming the literal text "none".
        name = str(payload.get("name") or "").lower()
        brand = str(payload.get("brand") or "").lower()

        if use_name_score and name:
            fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0
        else:
            fuzzy_name_score = 0.0
        fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0

        # Blend in the name score only when it is a reasonably strong match;
        # otherwise keep the pure semantic score from Qdrant.
        if fuzzy_name_score < 0.5:
            hybrid_score = point.score
        else:
            hybrid_score = 0.7 * point.score + 0.3 * fuzzy_name_score
        # A confident brand match boosts the score by 20%.
        if fuzzy_brand_score >= 0.8:
            hybrid_score *= 1.2

        payload["score"] = hybrid_score
        payload["fuzzy_name_score"] = fuzzy_name_score
        payload["fuzzy_brand_score"] = fuzzy_brand_score
        payload["semantic_score"] = point.score
        scored.append((point, hybrid_score))

    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [point for point, _ in scored]


def _render_product_card(payload, score_text):
    """Build the HTML card for one product, escaping every payload value."""
    name = html.escape(str(payload.get("name", "ไม่ทราบชื่อสินค้า")))
    brand = html.escape(str(payload.get("brand", "")))
    price = html.escape(str(payload.get("price", "ไม่ระบุ")))
    img_url = html.escape(str(payload.get("image_url", "")))
    return (
        '<div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; '
        'text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">'
        f'<img src="{img_url}" style="width: 100%; max-height: 150px; '
        'object-fit: contain; border-radius: 4px;">'
        '<div style="margin-top: 10px;">'
        f'<div style="font-weight: bold; font-size: 14px;">{name}</div>'
        f'<div style="color: gray; font-size: 12px;">{brand}</div>'
        f'<div style="color: green; margin: 4px 0;">฿{price}</div>'
        f'<div style="font-size: 12px; color: #555;">score: {html.escape(score_text)}</div>'
        '</div>'
        '</div>'
    )


def search_product(query):
    """Search products for ``query`` and stream UI updates.

    This is a generator: it first yields an "in progress" status, then a
    final ``(status_update, html)`` pair containing either the product grid,
    a "nothing found" notice, or an error message.

    On success, a snapshot of the search (raw query, normalized query,
    result summary, elapsed seconds) is stored in ``latest_query_result``
    for later feedback logging. The snapshot is written in a single step so
    a failed search never leaves it half-updated.

    Args:
        query: Raw text typed by the user.

    Yields:
        Tuples of (status textbox update, result HTML string).
    """
    yield gr.update(value="🔄 กำลังค้นหา..."), ""

    # Nothing to embed for a blank query; stop before hitting the model.
    if not query or not query.strip():
        yield gr.update(value="⚠️ กรุณาพิมพ์คำค้นหา"), ""
        return

    start_time = time.time()
    corrected_query = normalize(query)
    query_embed = model.encode(text=corrected_query)[0]

    try:
        points = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,
            limit=50,
        ).points
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        # The message is escaped because it is rendered as HTML.
        yield (
            gr.update(value="❌ Qdrant error"),
            f"<p>❌ Qdrant error: {html.escape(str(e))}</p>",
        )
        return

    ranked = _rerank_points(points, corrected_query)
    elapsed = time.time() - start_time

    # The user's text and payload values are untrusted: escape everything
    # that is interpolated into the HTML output.
    parts = [f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"]
    if corrected_query != query:
        parts.append(
            f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code> → "
            f"<code>{html.escape(corrected_query)}</code></p>"
        )
    parts.append(
        '<div style="display: grid; grid-template-columns: '
        'repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'
    )

    summaries = []
    for point in ranked:
        payload = point.payload
        if payload["score"] < threshold:
            continue  # only products at or above the threshold are shown
        score_text = f"{payload['score']:.4f}"
        parts.append(_render_product_card(payload, score_text))
        # The summary is logged to Airtable, so it keeps the raw name.
        summaries.append(
            f"{payload.get('name', 'ไม่ทราบชื่อสินค้า')} (score: {score_text})"
        )
    parts.append("</div>")

    if not summaries:
        parts.append(
            '<div style="text-align: center; font-size: 18px; color: #a00; '
            'padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>'
        )

    latest_query_result.update({
        "raw_query": query,
        "query": corrected_query,
        "result": " | ".join(summaries),  # no dangling separator
        "time": elapsed,
    })
    yield gr.update(value="✅ ค้นหาเสร็จแล้ว!"), "".join(parts)
# Feedback Function
def log_feedback(feedback):
    """Save the user's verdict on the most recent search as an Airtable row.

    The row is built from ``latest_query_result``; its keys must match the
    column names of the Airtable table.

    Args:
        feedback: Verdict label stored in the "feedback" column.

    Returns:
        A status message for the UI: success, a notice that there is no
        search to rate yet, or the reason the save failed.
    """
    # Before the first completed search the snapshot only holds empty
    # placeholders; sending it would create a meaningless (or rejected) row.
    if not latest_query_result["query"]:
        return "⚠️ No search result to rate yet. Please run a search first."

    record = {
        "model": "BGE M3",
        # NOTE(review): despite the column name, only the local date is
        # stored (no time of day, no timezone).
        "timestamp": datetime.now().strftime("%Y-%m-%d"),
        "raw_query": latest_query_result["raw_query"],
        "query": latest_query_result["query"],
        "result": latest_query_result["result"],
        "time(second)": latest_query_result["time"],
        "feedback": feedback,
    }
    try:
        table.create(record)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of raising.
        return f"❌ Failed to save feedback: {str(e)}"
    return "✅ Feedback saved to Airtable!"
# Gradio UI
# Layout, top to bottom: title, query box, HTML results, status box, a row
# with the two feedback buttons, and a feedback status box. Components are
# rendered in the order they are created, so the statement order below is
# the page layout.
# NOTE(review): the original indentation was lost; the feedback status box is
# assumed to sit below the button row rather than inside it — confirm.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (BGE M3 + Qdrant)")
    query_input = gr.Textbox(label="พิมพ์คำค้นหา")
    result_output = gr.HTML(label="📋 ผลลัพธ์")
    status_output = gr.Textbox(label="🕒 สถานะ", interactive=False)
    with gr.Row():
        match_btn = gr.Button("✅ ตรง")
        not_match_btn = gr.Button("❌ ไม่ตรง")
    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
    # Submitting the query box runs the search. search_product yields
    # (status, html) pairs, which map onto the two outputs in this order.
    query_input.submit(
        search_product,
        inputs=[query_input],
        outputs=[status_output, result_output]
    )
    # Each button logs a fixed verdict for the most recent search and shows
    # the returned message in the feedback status box.
    match_btn.click(fn=lambda: log_feedback("match"), outputs=feedback_status)
    not_match_btn.click(fn=lambda: log_feedback("not_match"), outputs=feedback_status)
# Start the app; share=True asks Gradio for a public share link.
demo.launch(share=True)