Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import time | |
| from datetime import datetime | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Filter, FieldCondition, MatchValue | |
| import os | |
| from rapidfuzz import process, fuzz | |
| from pythainlp.tokenize import word_tokenize | |
| from pyairtable import Table | |
| from pyairtable import Api | |
| import pickle | |
| import re | |
| import unicodedata | |
| from FlagEmbedding import FlagReranker | |
# Setup Qdrant Client
# Connection parameters come from the environment; an unset variable is
# passed through as None. Requests time out after 30 seconds.
qdrant_client = QdrantClient(
    timeout=30.0,
    api_key=os.getenv("Qdrant_api"),
    url=os.getenv("Qdrant_url"),
)
# Airtable Config
# Credentials/base id are read from the environment (None when unset);
# feedback rows are written to the fixed table below.
AIRTABLE_API_KEY = os.getenv("airtable_api")
BASE_ID = os.getenv("airtable_baseid")
TABLE_NAME = "Feedback_search"

api = Api(AIRTABLE_API_KEY)
# Handle used by log_feedback to create rows.
table = api.table(BASE_ID, TABLE_NAME)
# Load whitelist
# Keyword vocabulary used for query spell-correction.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# this is only acceptable if keyword_whitelist.pkl is a trusted, bundled
# artifact -- never load a user-supplied pickle here.
with open("keyword_whitelist.pkl", mode="rb") as whitelist_file:
    keyword_whitelist = pickle.load(whitelist_file)
# Preload Models
# Each UI label maps to: the embedding model, the Qdrant collection searched
# with it, the minimum embedding score a hit needs to be displayed, and the
# text prefix prepended to the query before encoding.
# Insertion order matters: the UI dropdown lists the keys in this order.
models = {}
models["E5 Finetuned"] = dict(
    model=SentenceTransformer("e5_finetuned"),
    collection="product_E5_finetune",
    threshold=0.8,
    prefix="query: ",
)
models["BGE M3"] = dict(
    model=SentenceTransformer("BAAI/bge-m3"),
    collection="product_bge-m3",
    threshold=0.45,
    prefix="",
)

# Cross-encoder that search_product uses to rerank BGE M3 candidates.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
# Utils
# Whole-string pattern: ASCII letters, ASCII digits, '&', '-' and whitespace.
_LATIN_ONLY_PATTERN = re.compile(r'[A-Za-z0-9&\-\s]+')


def is_non_thai(text):
    """Return True if *text* is non-empty and made only of ASCII letters,
    ASCII digits, '&', '-' and whitespace; False otherwise."""
    return _LATIN_ONLY_PATTERN.fullmatch(text) is not None
def join_corrected_tokens(corrected: list) -> str:
    """Drop one-character tokens (unless whitelisted) and join the rest.

    The separator is chosen from the *unfiltered* tokens: a single space when
    their concatenation passes ``is_non_thai`` (Latin-only text), otherwise
    the tokens are concatenated directly (presumably because Thai is written
    without spaces between words). An empty list yields "".
    """
    kept = [token for token in corrected if len(token) > 1 or token in keyword_whitelist]
    latin_only = bool(corrected) and is_non_thai("".join(corrected))
    separator = " " if latin_only else ""
    return separator.join(kept)
def normalize(text: str) -> str:
    """Canonicalize a query before tokenizing.

    Latin-only text (per ``is_non_thai``) is only stripped. Anything else is
    NFC-normalized, has the sequences "เแ" and "เเ" collapsed into "แ"
    (presumably mistyped forms of sara ae), then is stripped and lowercased.
    """
    if is_non_thai(text):
        return text.strip()
    text = unicodedata.normalize("NFC", text)
    # Order matters: "เแ" is rewritten before "เเ".
    for typo, fixed in (("เแ", "แ"), ("เเ", "แ")):
        text = text.replace(typo, fixed)
    return text.strip().lower()
def smart_tokenize(text: str) -> list:
    """Tokenize *text* with PyThaiNLP's "newmm" engine, with a fallback.

    The stripped text is tokenized; if the tokenizer returns nothing, or its
    tokens together cover fewer than half of the stripped text's characters,
    the whole stripped string is returned as a single token instead.
    """
    cleaned = text.strip()
    pieces = word_tokenize(cleaned, engine="newmm")
    coverage_ok = bool(pieces) and len("".join(pieces)) >= len(cleaned) * 0.5
    if not coverage_ok:
        return [cleaned]
    return pieces
def _match_whitelist(phrase, whitelist, threshold):
    """Return the whitelist entry *phrase* should become, or None if none fits."""
    if phrase in whitelist:
        return phrase
    # score_cutoff lets rapidfuzz skip hopeless candidates early, and makes
    # extractOne return None when nothing scores >= threshold -- including
    # when the whitelist is empty, which used to crash the tuple unpacking.
    best = process.extractOne(
        phrase,
        whitelist,
        scorer=fuzz.token_sort_ratio,
        processor=str.lower,
        score_cutoff=threshold,
    )
    return None if best is None else best[0]


def correct_query_merge_phrases(query: str, whitelist, threshold=80, max_ngram=3):
    """Spell-correct *query* against *whitelist*, merging multi-token phrases.

    The query is normalized and tokenized, then scanned left to right. At each
    position the longest window of up to *max_ngram* tokens (concatenated with
    no separator) is tried first. An exact whitelist hit is kept as-is;
    otherwise the best case-insensitive ``fuzz.token_sort_ratio`` match is
    substituted if it scores at least *threshold* (0-100). If no window
    starting at a position matches, that single token is kept unchanged.

    Args:
        query: raw user query.
        whitelist: known keywords; membership-tested and passed as choices
            to ``rapidfuzz.process.extractOne``.
        threshold: minimum fuzzy score required for a substitution.
        max_ngram: longest token window tried at each position.

    Returns:
        The corrected query, re-joined by ``join_corrected_tokens``.
    """
    tokens = smart_tokenize(normalize(query))
    corrected = []
    i = 0
    while i < len(tokens):
        longest = min(max_ngram, len(tokens) - i)
        for n in range(longest, 0, -1):
            replacement = _match_whitelist("".join(tokens[i:i + n]), whitelist, threshold)
            if replacement is not None:
                corrected.append(replacement)
                i += n
                break
        else:
            # Nothing starting at position i matched: keep the raw token.
            corrected.append(tokens[i])
            i += 1
    return join_corrected_tokens(corrected)
# Global state
# Snapshot of the most recent search, read by log_feedback.
# NOTE(review): this is a single module-level dict, so every request handled
# by this process shares it; concurrent users can overwrite each other's
# snapshot before feedback is logged (per-session gr.State would avoid this).
latest_query_result = dict(query="", result="", raw_query="", time="")
# Search Function
def _rerank_top_k(query_text, points, top_k=10, embed_weight=0.6, rerank_weight=0.4):
    """Reorder the first *top_k* points by a blend of embedding and reranker scores.

    Each head point is scored by the module-level cross-encoder ``reranker``
    on (query_text, payload name). The head is then sorted, highest first, by
    ``embed_weight * point.score + rerank_weight * rerank_score``. Points after
    *top_k* keep their order, and no point's ``score`` attribute is modified.
    """
    head = list(points[:top_k])
    if not head:
        return list(points)
    pairs = [[query_text, (p.payload or {}).get("name", "")] for p in head]
    scores = reranker.compute_score(pairs, normalize=True)
    try:
        scores = list(scores)
    except TypeError:
        # FlagReranker.compute_score can return a bare float for a single
        # pair; without this, zip() below raises when Qdrant returns one hit.
        scores = [scores]
    ranked = sorted(
        zip(head, scores),
        key=lambda pair: embed_weight * pair[0].score + rerank_weight * pair[1],
        reverse=True,
    )
    return [point for point, _ in ranked] + list(points[top_k:])


def _render_product_card(payload, name, score_text):
    """Return one product card as an HTML snippet, escaping every payload value."""
    import html

    img_url = html.escape(str(payload.get("imageUrl", "")))
    price = html.escape(str(payload.get("price", "ไม่ระบุ")))
    brand = html.escape(str(payload.get("brand", "")))
    safe_name = html.escape(str(name))
    return f"""
<div style="border: 1px solid #ddd; border-radius: 8px; padding: 10px; text-align: center; box-shadow: 1px 1px 5px rgba(0,0,0,0.1); background: #fff;">
    <img src="{img_url}" style="width: 100%; max-height: 150px; object-fit: contain; border-radius: 4px;">
    <div style="margin-top: 10px;">
        <div style="font-weight: bold; font-size: 14px;">{safe_name}</div>
        <div style="color: gray; font-size: 12px;">{brand}</div>
        <div style="color: green; margin: 4px 0;">฿{price}</div>
        <div style="font-size: 12px; color: #555;">score: {score_text}</div>
    </div>
</div>
"""


def _record_latest_query(raw_query, corrected_query, summary, elapsed):
    """Overwrite the shared feedback snapshot with all fields of one search at once."""
    # NOTE(review): latest_query_result is module-global and shared by every
    # session of this process; see its definition.
    latest_query_result.update({
        "raw_query": raw_query,
        "query": corrected_query,
        "result": summary,
        "time": elapsed,
    })


def search_product(query, model_choice):
    """Search Qdrant for products matching *query* and render them as HTML.

    The query is spell-corrected against ``keyword_whitelist``, then embedded
    with the model selected by *model_choice*, with that model's prefix
    prepended. The embedding fetches up to 50 ``type == "product"`` points
    from the model's collection. For "BGE M3" the top 10 are reordered with
    the cross-encoder reranker. Only points whose embedding score is at least
    the model's threshold are rendered.

    Side effect: every search that reaches Qdrant overwrites
    ``latest_query_result`` in full (raw query, corrected query, result
    summary, elapved seconds). This includes searches with no hits or a
    Qdrant error, so feedback always refers to the search the user just saw.

    Args:
        query: raw text typed by the user.
        model_choice: a key of ``models``.

    Returns:
        An HTML fragment: a timing line, an optional correction notice, and
        either the product grid or a "not found" message. On a Qdrant failure
        it returns only an error message.
    """
    import html

    start_time = time.time()
    selected = models[model_choice]
    threshold = selected["threshold"]

    corrected_query = correct_query_merge_phrases(query, keyword_whitelist)
    query_embed = selected["model"].encode(selected["prefix"] + corrected_query)

    try:
        # Fetch the top-50 candidates before reranking.
        result = qdrant_client.query_points(
            collection_name=selected["collection"],
            query=query_embed.tolist(),
            with_payload=True,
            query_filter=Filter(must=[FieldCondition(key="type", match=MatchValue(value="product"))]),
            limit=50,
        ).points
    except Exception as e:
        # UI boundary: report the failure instead of crashing the handler.
        _record_latest_query(query, corrected_query, "", time.time() - start_time)
        return f"<p>❌ Qdrant error: {html.escape(str(e))}</p>"

    # Rerank the top 10 with the cross-encoder (BGE M3 only).
    if model_choice == "BGE M3" and result:
        result = _rerank_top_k(corrected_query, result)

    elapsed = time.time() - start_time
    html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
    if corrected_query != query:
        # Both strings derive from user input: escape before embedding in HTML.
        html_output += f"<p>🔧 แก้คำค้นจาก: <code>{html.escape(query)}</code> → <code>{html.escape(corrected_query)}</code></p>"
    html_output += '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px;">'

    cards, summaries = [], []
    for res in result:
        # The cut-off uses the embedding score, even after reranking.
        if res.score < threshold:
            continue
        payload = res.payload or {}
        name = payload.get("name", "ไม่ทราบชื่อสินค้า")
        score_text = f"{res.score:.4f}"
        cards.append(_render_product_card(payload, name, score_text))
        summaries.append(f"{name} (score: {score_text})")

    html_output += "".join(cards) + "</div>"
    if not cards:
        html_output += '<div style="text-align: center; font-size: 18px; color: #a00; padding: 30px;">❌ ไม่พบสินค้าที่เกี่ยวข้องกับคำค้นนี้</div>'

    _record_latest_query(query, corrected_query, " | ".join(summaries), elapsed)
    return html_output
# Feedback Function
def log_feedback(feedback, model_choice):
    """Append one feedback row about the most recent search to Airtable.

    Args:
        feedback: the user's verdict (the UI buttons send "match" or
            "not_match").
        model_choice: the model label currently selected in the UI.

    Returns:
        A status message for the UI. Airtable failures are reported in the
        message rather than raised. Nothing is written if no search has been
        run yet.
    """
    # Read the shared snapshot once, so every field in the row comes from the
    # same search.
    snapshot = dict(latest_query_result)
    if not snapshot["raw_query"]:
        # No search yet: writing would create an empty, useless row.
        return "⚠️ No search to give feedback on yet."
    try:
        table.create({
            "model": model_choice,
            # Date only (YYYY-MM-DD), despite the field name.
            "timestamp": datetime.now().strftime("%Y-%m-%d"),
            "raw_query": snapshot["raw_query"],
            "query": snapshot["query"],
            "result": snapshot["result"],
            "time(second)": snapshot["time"],
            "feedback": feedback,
        })
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing.
        return f"❌ Failed to save feedback: {str(e)}"
    return "✅ Feedback saved to Airtable!"
# Gradio UI
# NOTE(review): the scraped source lost its indentation, so which components
# sit inside each gr.Row below is a reconstruction, not a certainty. Check it
# against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🔎 Product Semantic Search (Vector Search + Qdrant)")
    with gr.Row():
        # Model picker: choices are the keys of `models`; defaults to the fine-tuned E5 model.
        model_selector = gr.Dropdown(label="🔍 เลือกโมเดล", choices=list(models.keys()), value="E5 Finetuned")
        # Search box (label: "type a search query").
        query_input = gr.Textbox(label="พิมพ์คำค้นหา")
    # Result grid rendered by search_product (label: "results").
    result_output = gr.HTML(label="📋 ผลลัพธ์")
    with gr.Row():
        # Feedback buttons (labels: "match" / "not a match").
        match_btn = gr.Button("✅ ตรง")
        not_match_btn = gr.Button("❌ ไม่ตรง")
    # Shows the status string returned by log_feedback (label: "feedback status").
    feedback_status = gr.Textbox(label="📬 สถานะ Feedback")
    # Submitting the search box runs search_product with the selected model.
    query_input.submit(search_product, inputs=[query_input, model_selector], outputs=result_output)
    # Each button logs its fixed verdict together with the currently selected model.
    match_btn.click(fn=lambda model: log_feedback("match", model), inputs=model_selector, outputs=feedback_status)
    not_match_btn.click(fn=lambda model: log_feedback("not_match", model), inputs=model_selector, outputs=feedback_status)
# share=True requests a public Gradio share link in addition to the local server.
demo.launch(share=True)