File size: 3,929 Bytes
587c522
dfabed7
d802f7b
e0f269f
d802f7b
e0f269f
d802f7b
e0f269f
 
245e246
e0f269f
d802f7b
c56e053
 
2bd39e5
 
587c522
 
 
dfabed7
587c522
dfabed7
587c522
 
c56e053
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587c522
 
 
 
 
 
 
 
 
 
 
 
 
68d8ac5
587c522
 
 
 
68d8ac5
c56e053
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import gradio as gr
import json
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai

from helpers import (
    list_docx_files, get_splits, get_json_splits_only, prompt_order, log_message
)

from file_loader import get_vectorstore

os.environ["USER_AGENT"] = "gradio-hf-space"

# SECURITY NOTE(review): this API key is committed to source and should be
# treated as leaked. Revoke/rotate it and supply GOOGLE_API_KEY only through
# the environment (e.g. a deployment secret). It is kept here solely as a
# fallback so existing deployments without the env var keep working.
os.environ.setdefault("GOOGLE_API_KEY", "AIzaSyDJ4vIKuIBIPNHATLxnoHlagXWbsAz-vRs")

# Use the effective key (environment first, committed fallback second) so the
# genai client and anything that reads GOOGLE_API_KEY agree on the same key.
key = os.environ["GOOGLE_API_KEY"]

# Configure the API key for the Google GenAI client.
genai.configure(api_key=key)

# Vector store queried by augment_prompt() for retrieval.
vectorstore = get_vectorstore()
def augment_prompt(query: str, k: int = 10) -> str:
    """Build the retrieval-augmented context prompt for *query*.

    Requests *k* documents from the module-level ``vectorstore`` retriever and
    joins their ``page_content`` (separated by blank lines) into a single
    context section.

    Args:
        query: The user's question, used as the retrieval query.
        k: Number of documents to request from the retriever.

    Returns:
        A prompt string containing the retrieved contexts, or a short
        "No relevant context found." message when retrieval returns nothing.
        The query text itself is not embedded in the returned prompt.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    results = retriever.invoke(query)

    if not results:
        return "No relevant context found.\n."

    source_knowledge = "\n\n".join(doc.page_content for doc in results)
    return f"""Using the contexts below, answer the query.
Contexts:
{source_knowledge}
"""

def get_answer(query, queries_list=None):
    """Stream a Gemini answer to *query*, grounded in retrieved contexts.

    This is the callback given to ``gr.ChatInterface``, so Gradio passes the
    chat history as the second positional argument (*queries_list*).

    Args:
        query: The latest user message.
        queries_list: Prior conversation items, or None. It is never
            modified; a new list with *query* appended is passed to
            ``prompt_order``.

    Yields:
        The accumulated response text after each streamed chunk, so the UI
        can render the answer progressively.

    After the stream is fully consumed, the instruction message, the context
    message and the model reply are passed to ``log_message``.
    """
    messages = [
        {"role": "user", "parts": [{"text": "IMPORTANT: You are a super energetic, helpful, polite, Vietnamese-speaking assistant. If you can not see the answer in contexts, try to search it up online by yourself but remember to give the source."}]},
        {"role": "user", "parts": [{"text": augment_prompt(query)}]},
    ]
    # NOTE: disabled extra-resources text, kept for reference:
    #     bonus = '''
    # Bạn tham kháo thêm các nguồn thông tin tại:
    # Trang thông tin điện tử: https://neu.edu.vn ; https://daotao.neu.edu.vn
    # Trang mạng xã hội có thông tin tuyển sinh: https://www.facebook.com/ktqdNEU ; https://www.facebook.com/tvtsneu ;
    # Email tuyển sinh: [email protected]
    # Số điện thoại tuyển sinh: 0888.128.558
    #   '''

    # Build a new list instead of appending in place: the incoming list may
    # be owned by the caller (Gradio's chat history) and must not be mutated.
    queries_list = [*(queries_list or []), query]
    queries = {"role": "user", "parts": [{"text": prompt_order(queries_list)}]}
    messages_with_queries = [*messages, queries]

    # Configure the API key (mirrors the module-level configuration).
    genai.configure(api_key=key)

    # Initialize the Gemini model and request a streamed response.
    model = genai.GenerativeModel("gemini-2.0-flash")
    response = model.generate_content(contents=messages_with_queries, stream=True)

    response_text = ""
    for chunk in response:
        response_text += chunk.text
        yield response_text

    messages.append({"role": "model", "parts": [{"text": response_text}]})
    log_message(messages)

# Dropdown options; the catch-all "Tất cả" ("All") entry comes first in both.
institutions = ['Tất cả', 'Trường Công Nghệ']
categories = ['Tất cả', 'Đề án', 'Chương trình đào tạo']

with gr.Blocks() as demo:
    # NOTE(review): these dropdowns are rendered but not passed to the chat
    # interface (no additional_inputs), so their values are currently unused.
    with gr.Row():
        category1 = gr.Dropdown(choices=institutions, label="Trường", value='Tất cả')
        category2 = gr.Dropdown(choices=categories, label="Bạn quan tâm tới", value='Tất cả')

    # Question input box, built just before the chat interface that owns it.
    question_box = gr.Textbox(
        placeholder="Đặt câu hỏi tại đây",
        container=False,
        autoscroll=True,
        scale=7,
    )
    chat_interface = gr.ChatInterface(
        get_answer,
        textbox=question_box,
        type="messages",
    )

if __name__ == "__main__":
    demo.launch()