# app.py -- from the "environment" repository (header residue of a web file viewer).
# Last commit: "Update app.py" (41cf03d, verified), committer avatar: sibthinon; file size 7.85 kB.
import html
import os
import pickle
import re
import time
import unicodedata
from datetime import datetime

import gradio as gr
import pandas as pd
from pyairtable import Api
from pyairtable import Table
from pythainlp.tokenize import word_tokenize
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from rapidfuzz import process, fuzz
from sentence_transformers import SentenceTransformer
# Qdrant client; the endpoint URL and API key are read from environment
# variables (both are None when the variable is unset).
qdrant_client = QdrantClient(
    url=os.getenv("Qdrant_url"),
    api_key=os.getenv("Qdrant_api"),
)

# Airtable credentials and the table that receives feedback rows.
AIRTABLE_API_KEY = os.getenv("airtable_api")
BASE_ID = os.getenv("airtable_baseid")
TABLE_NAME = "Feedback_search"  # or rename to something clearer, e.g. 'Feedback'

api = Api(AIRTABLE_API_KEY)
table = api.table(BASE_ID, TABLE_NAME)
# Preloaded models: every SentenceTransformer is constructed once, at import
# time, and looked up by the label shown in the UI dropdown.
models = {
    label: SentenceTransformer(repo_id)
    for label, repo_id in (
        ("E5 (intfloat/multilingual-e5-small)", 'intfloat/multilingual-e5-small'),
        ("E5 large instruct (multilingual-e5-large-instruct)", "intfloat/multilingual-e5-large-instruct"),
        ("Kalm (KaLM-embedding-multilingual-mini-v1)", 'HIT-TMG/KaLM-embedding-multilingual-mini-v1'),
    )
}
def _embed_e5_small(query):
    """Encode *query* with the small E5 model, prefixed with ``query: ``."""
    return models["E5 (intfloat/multilingual-e5-small)"].encode("query: " + query)


def _embed_e5_large_instruct(query):
    """Encode *query* with the E5 large instruct model behind an instruction prompt."""
    prompt = "Instruct: Given a product search query, retrieve relevant product listings\nQuery: " + query
    encoder = models["E5 large instruct (multilingual-e5-large-instruct)"]
    return encoder.encode(prompt, convert_to_tensor=False, normalize_embeddings=True)


def _embed_kalm(query):
    """Encode *query* with the KaLM model, requesting normalized embeddings."""
    encoder = models["Kalm (KaLM-embedding-multilingual-mini-v1)"]
    return encoder.encode(query, normalize_embeddings=True)


# Per-model settings keyed by the UI label:
#   "func"       -- callable that turns a query string into an embedding
#   "collection" -- name of the Qdrant collection searched for that model
model_config = {
    "E5 (intfloat/multilingual-e5-small)": {
        "func": _embed_e5_small,
        "collection": "product_E5",
    },
    "E5 large instruct (multilingual-e5-large-instruct)": {
        "func": _embed_e5_large_instruct,
        "collection": "product_E5_large_instruct",
    },
    "Kalm (KaLM-embedding-multilingual-mini-v1)": {
        "func": _embed_kalm,
        "collection": "product_kalm",
    },
}
# Module-level record of the most recent search. search_product() writes it
# and log_feedback() reads it back when a feedback button is pressed.
# NOTE(review): this dict is a single module global, so it looks like it is
# shared by every session of the app -- confirm that is intended.
latest_query_result = dict.fromkeys(("query", "result", "model", "raw_query", "time"), "")

# Whitelist of known keywords used for fuzzy query correction.
# NOTE: unpickling can execute arbitrary code; only ship a trusted file.
with open("keyword_whitelist.pkl", mode="rb") as f:
    keyword_whitelist = pickle.load(f)
def normalize(text: str) -> str:
    """Return *text* in the canonical form used for whitelist matching.

    The text is NFC-normalized, the sequences "เแ" and "เเ" are rewritten
    to "แ" (in that order), and the result is stripped of surrounding
    whitespace and lower-cased.
    """
    canonical = unicodedata.normalize("NFC", text)
    for sequence, replacement in (("เแ", "แ"), ("เเ", "แ")):
        canonical = canonical.replace(sequence, replacement)
    trimmed = canonical.strip()
    return trimmed.lower()
def smart_tokenize(text: str) -> list:
    """Tokenize *text* with ``word_tokenize`` (engine ``newmm``).

    Returns the tokenizer's output, unless it is empty or its tokens add up
    to less than half the length of the stripped text; in those cases the
    stripped text is returned as a single-element list.
    """
    stripped = text.strip()
    pieces = word_tokenize(stripped, engine="newmm")
    covered = sum(len(piece) for piece in pieces)  # same as len("".join(pieces))
    if pieces and covered >= len(stripped) * 0.5:
        return pieces
    return [stripped]
def correct_query_merge_phrases(query: str, whitelist, threshold=75, max_ngram=3):
    """Fuzzy-correct *query* against *whitelist*, merging adjacent tokens.

    The query is normalized and tokenized, then scanned left to right. At
    each position the longest n-gram (up to *max_ngram* tokens, concatenated
    without a separator) is compared against the whitelist first; the first
    n-gram whose best ``fuzz.ratio`` match scores at least *threshold* is
    replaced by that whitelist entry and the scan skips past it. Tokens with
    no acceptable match are kept unchanged.

    Args:
        query: Raw search text.
        whitelist: Collection of known keywords passed to
            ``process.extractOne`` as the choices.
        threshold: Minimum ``fuzz.ratio`` score (0-100) for a replacement.
        max_ngram: Maximum number of consecutive tokens merged into a phrase.

    Returns:
        The corrected words joined by single spaces. Words of length one
        that are not in the whitelist are dropped.
    """
    tokens = smart_tokenize(normalize(query))
    corrected = []
    i = 0
    while i < len(tokens):
        # Longest phrase first, so a multi-token match wins over its parts.
        for n in range(min(max_ngram, len(tokens) - i), 0, -1):
            phrase = "".join(tokens[i:i + n])
            # With score_cutoff, extractOne returns None when nothing reaches
            # the threshold -- and also when the whitelist is empty, a case
            # that previously crashed on tuple unpacking.
            best = process.extractOne(
                phrase, whitelist, scorer=fuzz.ratio, score_cutoff=threshold
            )
            if best is not None:
                corrected.append(best[0])
                i += n
                break
        else:
            # No n-gram starting here matched: keep the token as typed.
            corrected.append(tokens[i])
            i += 1

    # Drop one-character words unless the whitelist contains them.
    cleaned = [word for word in corrected if len(word) > 1 or word in whitelist]
    return " ".join(cleaned)
def _render_product_card(name, brand, price, img_url, score):
    """Return the HTML card for one product; every value is HTML-escaped."""
    name, brand, price, img_url = (
        html.escape(str(value)) for value in (name, brand, price, img_url)
    )
    return f"""
<div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
<img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
<div style="margin-top: 10px;">
<div style="font-weight: bold; font-size: 14px;">{name}</div>
<div style="color: gray; font-size: 12px;">{brand}</div>
<div style="color: green; margin: 4px 0;">฿{price}</div>
<div style="font-size: 12px; color: #555;">score: {score}</div>
</div>
</div>
"""


# 🌟 Main search function
def search_product(query, model_name):
    """Search products for *query* and return the results as an HTML string.

    The query is corrected against the keyword whitelist, embedded with the
    model selected by *model_name*, and used to fetch up to 10 points whose
    payload field ``type`` equals ``"product"`` from that model's Qdrant
    collection. On success the module-level ``latest_query_result`` is
    updated so that a later feedback click refers to this search.

    Args:
        query: Search text typed by the user (untrusted; escaped before it
            is placed into HTML).
        model_name: A key of ``model_config``.

    Returns:
        An HTML fragment: either a grid of product cards preceded by the
        elapsed time (and the corrected query, when it differs), or a single
        ``<p>`` error message for an unknown model, an empty query, or a
        Qdrant failure.
    """
    start_time = time.time()

    if model_name not in model_config:
        return "<p>❌ ไม่พบโมเดล</p>"
    # Nothing to embed or search for; avoids a pointless model + Qdrant call.
    if not query or not query.strip():
        return "<p>❌ กรุณาพิมพ์คำค้นหา</p>"

    corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
    query_embed = model_config[model_name]["func"](corrected_query)
    collection_name = model_config[model_name]["collection"]

    try:
        result = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_embed.tolist(),
            with_payload=True,
            query_filter=Filter(must=[FieldCondition(key="type", match=MatchValue(value="product"))]),
            limit=10,
        ).points
    except Exception as e:
        # UI boundary: show the failure instead of crashing the request.
        return f"<p>❌ Qdrant error: {html.escape(str(e))}</p>"

    elapsed = time.time() - start_time

    parts = [f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"]
    if corrected_query != query:
        parts.append(
            f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code> → "
            f"<code>{html.escape(corrected_query)}</code></p>"
        )
    parts.append(
        '\n<div style="display: grid; grid-template-columns: '
        'repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">\n'
    )

    summary_parts = []
    for res in result:
        payload = res.payload or {}
        name = payload.get("name", "ไม่ทราบชื่อสินค้า")
        score = f"{res.score:.4f}"
        parts.append(
            _render_product_card(
                name=name,
                brand=payload.get("brand", ""),
                price=payload.get("price", "ไม่ระบุ"),
                img_url=payload.get("imageUrl", ""),
                score=score,
            )
        )
        # The summary goes to the feedback log, not into HTML, so it stays raw.
        summary_parts.append(f"{name} (score: {score}) | ")
    parts.append("</div>")

    # Record the search only once it has fully succeeded, so feedback is
    # never logged against a mix of a new raw query and old results.
    latest_query_result.update(
        {
            "raw_query": query,
            "query": corrected_query,
            "result": "".join(summary_parts).strip(),
            "model": model_name,
            "time": elapsed,
        }
    )
    return "".join(parts)
# 📝 Logging feedback
def log_feedback(feedback):
    """Save *feedback* for the most recent search as a new Airtable row.

    The row combines today's date (``YYYY-MM-DD``), the fields stored in the
    module-level ``latest_query_result``, and the given feedback value.

    Returns:
        A status string: a success message, or a failure message containing
        the text of the exception that was raised.
    """
    try:
        record = {"timestamp": datetime.now().strftime("%Y-%m-%d")}
        # (Airtable column, key in latest_query_result)
        for column, key in (
            ("raw_query", "raw_query"),
            ("model", "model"),
            ("query", "query"),
            ("result", "result"),
            ("time(second)", "time"),
        ):
            record[column] = latest_query_result[key]
        record["feedback"] = feedback
        table.create(record)
    except Exception as e:
        return f"❌ Failed to save feedback: {e!s}"
    else:
        return "✅ Feedback saved to Airtable!"
# 🎨 Gradio UI
# NOTE(review): the source lost its indentation; the nesting of the `with`
# blocks below is reconstructed (two rows, everything inside Blocks) -- confirm.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")

    # Row 1: model picker next to the query box.
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=list(models),
            label="เลือกโมเดล",
            value="E5 (intfloat/multilingual-e5-small)",
        )
        query_input = gr.Textbox(label="พิมพ์คำค้นหา")

    # HTML area showing the results together with product images.
    result_output = gr.HTML(label="📋 ผลลัพธ์")

    # Row 2: feedback buttons for the results currently shown.
    with gr.Row():
        match_btn = gr.Button("✅ ตรง")
        not_match_btn = gr.Button("❌ ไม่ตรง")

    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")

    # Pressing Enter in the query box runs the search.
    query_input.submit(
        search_product,
        inputs=[query_input, model_selector],
        outputs=result_output,
    )
    # Each button logs a fixed feedback label for the latest search.
    match_btn.click(lambda: log_feedback("match"), outputs=feedback_status)
    not_match_btn.click(lambda: log_feedback("not_match"), outputs=feedback_status)

# Run app
demo.launch(share=True)