import gradio as gr
import time
from datetime import datetime
import pandas as pd
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
import os
from rapidfuzz import process, fuzz
from pythainlp.tokenize import word_tokenize
from pyairtable import Api
import pickle
import re
import unicodedata
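
# Required environment variables (read via os.environ.get below):
#   Qdrant_url, Qdrant_api         -> Qdrant endpoint URL and API key
#   airtable_api, airtable_baseid  -> Airtable API key and base ID used for feedback logging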
qdrant_client = QdrantClient(
    url=os.environ.get("Qdrant_url"),
    api_key=os.environ.get("Qdrant_api"),
)

AIRTABLE_API_KEY = os.environ.get("airtable_api")
BASE_ID = os.environ.get("airtable_baseid")
TABLE_NAME = "Feedback_search"  # or rename to something clearer, e.g. 'Feedback'

api = Api(AIRTABLE_API_KEY)
table = api.table(BASE_ID, TABLE_NAME)

# Preloaded embedding models
models = {
    "E5 (intfloat/multilingual-e5-small)": SentenceTransformer('intfloat/multilingual-e5-small'),
    "E5 large instruct (multilingual-e5-large-instruct)": SentenceTransformer("intfloat/multilingual-e5-large-instruct"),
    "Kalm (KaLM-embedding-multilingual-mini-v1)": SentenceTransformer('HIT-TMG/KaLM-embedding-multilingual-mini-v1')
}
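
# Maps each UI model name to its query-encoding function and the Qdrant collection
# searched with that model. The two E5 models are given their recommended query
# prefixes ("query: " / "Instruct: ... Query: ") before encoding.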
model_config = {
    "E5 (intfloat/multilingual-e5-small)": {
        "func": lambda query: models["E5 (intfloat/multilingual-e5-small)"].encode("query: " + query),
        "collection": "product_E5",
    },
    "E5 large instruct (multilingual-e5-large-instruct)": {
        "func": lambda query: models["E5 large instruct (multilingual-e5-large-instruct)"].encode(
            "Instruct: Given a product search query, retrieve relevant product listings\nQuery: " + query,
            convert_to_tensor=False, normalize_embeddings=True),
        "collection": "product_E5_large_instruct",
    },
    "Kalm (KaLM-embedding-multilingual-mini-v1)": {
        "func": lambda query: models["Kalm (KaLM-embedding-multilingual-mini-v1)"].encode(query, normalize_embeddings=True),
        "collection": "product_kalm",
    }
}
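
# Note: each collection above is assumed to already contain product vectors encoded
# with the matching model; this app only queries them, it does not index anything.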

# Global memory to hold feedback state (written by search_product, read by log_feedback)
latest_query_result = {"query": "", "result": "", "model": "", "raw_query": "", "time": ""}

with open("keyword_whitelist.pkl", "rb") as f:
    keyword_whitelist = pickle.load(f)
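# keyword_whitelist is expected to be a pre-built iterable of known product keywords
# (names, brands, etc.); it is the candidate list for the fuzzy correction below.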


def normalize(text: str) -> str:
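    """Normalize a Thai query string: NFC-normalize, fix common เแ/เเ -> แ typos,
    trim whitespace, and lowercase."""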
    text = unicodedata.normalize("NFC", text)
    text = text.replace("เแ", "แ").replace("เเ", "แ")
    return text.strip().lower()


def smart_tokenize(text: str) -> list:
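    """Tokenize with PyThaiNLP's newmm engine; if tokenization drops too many
    characters (joined tokens shorter than 50% of the input), fall back to the raw text."""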
    tokens = word_tokenize(text.strip(), engine="newmm")
    if not tokens or len("".join(tokens)) < len(text.strip()) * 0.5:
        return [text.strip()]
    return tokens


def correct_query_merge_phrases(query: str, whitelist, threshold=75, max_ngram=3):
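    """Spell-correct a query against the keyword whitelist.

    Greedily tries to merge up to `max_ngram` consecutive tokens into one phrase and
    replaces it with the best fuzzy whitelist match when the RapidFuzz ratio score is
    at least `threshold`; otherwise the original token is kept.
    """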
    query_norm = normalize(query)
    tokens = smart_tokenize(query_norm)
    corrected = []
    i = 0
    while i < len(tokens):
        matched = False
        for n in range(min(max_ngram, len(tokens) - i), 0, -1):
            phrase = "".join(tokens[i:i+n])
            match, score, _ = process.extractOne(phrase, whitelist, scorer=fuzz.ratio)
            if score >= threshold:
                corrected.append(match)
                i += n
                matched = True
                break
        if not matched:
            corrected.append(tokens[i])
            i += 1
    # ✅ Drop single-character tokens unless they appear in the whitelist
    cleaned = [word for word in corrected if len(word) > 1 or word in whitelist]
    return " ".join(cleaned)


# 🌟 Main search function
def search_product(query, model_name):
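    """Encode the (spell-corrected) query with the selected model, run a vector search
    against the matching Qdrant collection (filtered to payload type == "product"),
    and return the top-10 hits rendered as an HTML card grid. Also caches the query,
    results, model, and latency in latest_query_result for feedback logging."""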
    start_time = time.time()

    if model_name not in model_config:
        return "<p>❌ ไม่พบโมเดล</p>"

    latest_query_result["raw_query"] = query
    corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
    query_embed = model_config[model_name]["func"](corrected_query)
    collection_name = model_config[model_name]["collection"]

    try:
        result = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,
            query_filter=Filter(must=[FieldCondition(key="type", match=MatchValue(value="product"))]),
            limit=10
        ).points
    except Exception as e:
        return f"<p>❌ Qdrant error: {str(e)}</p>"

    elapsed = time.time() - start_time
    html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
    if corrected_query != query:
        html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"

    html_output += """
    <div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">
    """

    result_summary = ""
    for res in result:
        name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
        score = f"{res.score:.4f}"
        img_url = res.payload.get("imageUrl", "")
        price = res.payload.get("price", "ไม่ระบุ")
        brand = res.payload.get("brand", "")

        html_output += f"""
        <div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
            <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
            <div style="margin-top: 10px;">
                <div style="font-weight: bold; font-size: 14px;">{name}</div>
                <div style="color: gray; font-size: 12px;">{brand}</div>
                <div style="color: green; margin: 4px 0;">฿{price}</div>
                <div style="font-size: 12px; color: #555;">score: {score}</div>
            </div>
        </div>
        """
        result_summary += f"{name} (score: {score}) | "

    html_output += "</div>"

    latest_query_result["query"] = corrected_query
    latest_query_result["result"] = result_summary.strip()
    latest_query_result["model"] = model_name
    latest_query_result["time"] = elapsed

    return html_output


# 📝 Logging feedback
def log_feedback(feedback):
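    """Write the most recent search (stored in latest_query_result) plus the user's
    match / not_match verdict to the Airtable feedback table."""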
    try:
        now = datetime.now().strftime("%Y-%m-%d")
        table.create({
            "timestamp": now,
            "raw_query": latest_query_result["raw_query"],
            "model": latest_query_result["model"],
            "query": latest_query_result["query"],
            "result": latest_query_result["result"],
            "time(second)": latest_query_result["time"],
            "feedback": feedback
        })
        return "✅ Feedback saved to Airtable!"
    except Exception as e:
        return f"❌ Failed to save feedback: {str(e)}"


# 🎨 Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=list(models.keys()),
            label="เลือกโมเดล",
            value="E5 (intfloat/multilingual-e5-small)"
        )
        query_input = gr.Textbox(label="พิมพ์คำค้นหา")

    result_output = gr.HTML(label="📋 ผลลัพธ์")  # HTML output: result cards with images

    with gr.Row():
        match_btn = gr.Button("✅ ตรง")
        not_match_btn = gr.Button("❌ ไม่ตรง")

    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")

    query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
    match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
    not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)

# Run app
demo.launch(share=True)
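
# Note: to run this outside of Hugging Face Spaces, set the four environment variables
# listed at the top and place keyword_whitelist.pkl alongside this script; share=True
# only controls whether Gradio also exposes a temporary public link.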