Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
import gradio as gr | |
import json | |
from tqdm import tqdm | |
from langchain_community.vectorstores import FAISS | |
from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
import google.generativeai as genai | |
# from playwright._impl._driver import get_driver_dir | |
from helpers import ( | |
list_docx_files, get_splits, get_json_splits_only, prompt_order, log_message, extract_metadata | |
) | |
from file_loader import get_vectorstore | |
# import asyncio | |
# key = "AIzaSyDJ4vIKuIBIPNHATLxnoHlagXWbsAz-vRs" | |
key = os.getenv("GOOGLE_API_KEY") | |
# Cấu hình API key cho Google GenAI | |
genai.configure(api_key=key) | |
vectorstore = get_vectorstore() | |
last_vector_docs = None # Lưu kết quả docs từ vectorstore.invoke trong lần gọi get_answer gần nhất | |
see_metadata = None | |
def augment_prompt(query: str, k: int = 5): | |
#define metadata | |
messages = [ | |
{"role": "user", "parts": [{"text": """ | |
{} | |
'Tai lieu ve': ['Chương trình đào tạo', 'Đề án', 'Đề cương'] | |
'Nganh': ['Trí tuệ nhân tạo', | |
'Toán kinh tế', | |
'Thống kê kinh tế', | |
'Phân tích dữ liệu trong Kinh tế', | |
'Kỹ thuật phần mềm', | |
'Khoa học máy tính', | |
'Khoa học dữ liệu', | |
'Hệ thống thông tin quản lý', | |
'Hệ thống thông tin', | |
'Định phí bảo hiểm và Quản trị rủi ro', | |
'Chương trình Công nghệ thông tin', | |
'An toàn thông tin'] | |
Nhiệm vụ của bạn là viết output dưới dạng dict để xác thực metadata, output có dạng: | |
{'Tai lieu ve': '<write here>', 'Nganh': <write here>}, nếu không có theo 2 lists bên trên thì trả về dict rỗng, nếu có nhiều kết quả cùng một key thì trả về dict rỗng. | |
"""}]}, | |
{"role": "user", "parts": [{"text": f'Câu hỏi như sau: {query}'}]}, | |
] | |
genai.configure(api_key=key) | |
model = genai.GenerativeModel("gemini-2.0-flash") | |
response = model.generate_content(contents=messages) | |
global see_metadata | |
metadata = extract_metadata(response) | |
see_metadata = metadata | |
#retrieve | |
global last_vector_docs | |
if metadata == {}: | |
retriever = vectorstore.as_retriever(search_kwargs={"k": 50}) | |
else: | |
retriever = vectorstore.as_retriever(search_kwargs={"k": k, "fetch_k": 50, "filter": metadata}) | |
results = retriever.invoke(query) | |
# Lưu kết quả để dùng cho log và lọc sau này | |
last_vector_docs = results | |
if results: | |
source_knowledge = "\n\n".join([doc.page_content for doc in results]) | |
return f"""Dữ liệu dưới đây liên quan đến Trường Công Nghệ (NCT) thuộc Đại học Kinh tế Quốc dân (NEU), dựa vào đó trả lời câu hỏi. | |
Dữ liệu: | |
{source_knowledge} | |
""" | |
else: | |
return "Không có thông tin liên quan.\n." | |
def get_answer(query, queries_list=None): | |
if queries_list is None: | |
queries_list = [] | |
messages = [ | |
{"role": "user", "parts": [{"text": "IMPORTANT: You are a super helpful, polite, Vietnamese-speaking assistant to give information of an university. If you cannot see the answer in contexts, tell user to make a more detailed question."}]}, | |
{"role": "user", "parts": [{"text": augment_prompt(query=query)}]} | |
] | |
queries_list.append(query) | |
queries = {"role": "user", "parts": [{"text": prompt_order(queries_list)}]} | |
messages_with_queries = messages.copy() | |
messages_with_queries.append(queries) | |
# Cấu hình API key và khởi tạo model Gemini | |
genai.configure(api_key=key) | |
model = genai.GenerativeModel("gemini-2.0-flash") | |
response = model.generate_content(contents=messages_with_queries, stream=True) | |
response_text = "" | |
for chunk in response: | |
response_text += chunk.text | |
yield response_text | |
messages.append({"role": "model", "parts": [{"text": response_text}]}) | |
log_message(messages) | |
def filter_vector_docs(keyword: str): | |
global last_vector_docs | |
if last_vector_docs is None: | |
return "Chưa có dữ liệu vectorstore được gọi từ get_answer." | |
else: | |
if not keyword.strip(): | |
# Nếu không nhập gì, trả về tất cả | |
filtered = [doc.page_content for doc in last_vector_docs] | |
else: | |
# Lọc các chunk chứa từ khoá (không phân biệt chữ hoa thường) | |
filtered = [doc.page_content for doc in last_vector_docs if keyword.lower() in doc.page_content.lower()] | |
if not filtered: | |
return f"Không có kết quả chứa từ khoá '{keyword}'." | |
return "\n\n".join(filtered) | |
institutions = ['Tất cả', 'Trường Công Nghệ'] | |
categories = ['Tất cả', 'Đề án', 'Chương trình đào tạo'] | |
print("Launching on space... This may take some time...") | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
# Dropdown category nếu cần | |
category1 = gr.Dropdown(choices=institutions, label="Trường", value=None) | |
category2 = gr.Dropdown(choices=categories, label="Bạn quan tâm tới", value=None) | |
# Chat Interface sử dụng ô nhập chung | |
shared_query = gr.Textbox(placeholder="Đặt câu hỏi tại đây", container=False, autoscroll=True, scale=7) | |
chat_interface = gr.ChatInterface(get_answer, textbox=shared_query, type="messages") | |
metadata_box = gr.Textbox(label="Metadata", value=str(see_metadata), interactive=False) | |
# Phần lọc các chunk: ô prompt nhập từ khoá và nút "Tìm trích xuất" nằm cùng hàng, | |
# kết quả sẽ hiển thị ở ô bên dưới. Nếu để trống, hiển thị toàn bộ. | |
with gr.Row(): | |
filter_prompt = gr.Textbox(label="Nhập từ khoá", placeholder="Nhập từ khoá để lọc (để trống để hiển thị tất cả)", interactive=True) | |
filter_button = gr.Button("Tìm trích xuất") | |
filter_output = gr.Textbox(label="Content", interactive=False) | |
filter_button.click(fn=filter_vector_docs, inputs=filter_prompt, outputs=filter_output) | |
if __name__ == "__main__": | |
demo.launch() |