Commit a098bc7
1 Parent(s): 59959c6
Initialize app
Files changed:
- .gitignore            +2    -0
- Dockerfile            +4    -6
- requirements.txt      +0    -0
- src/rag.py            +280  -0
- src/streamlit_app.py  +62   -38
.gitignore
ADDED
@@ -0,0 +1,2 @@
venv
.env
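The two ignored entries line up with how the app is configured: src/rag.py calls load_dotenv() and reads GOOGLE_API_KEY and MONGODB_URI from the environment, and venv is the local virtual environment. A local .env file would therefore look roughly like this (placeholder values, not part of this commit):

    GOOGLE_API_KEY=<your Google Generative AI key>
    MONGODB_URI=mongodb+srv://<user>:<password>@<cluster-host>/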
Dockerfile
CHANGED
@@ -1,13 +1,11 @@
-FROM python:3.
+FROM python:3.12-alpine

 WORKDIR /app

-RUN
-    build-
+RUN apk add --no-cache \
+    build-base \
     curl \
-
-    git \
-    && rm -rf /var/lib/apt/lists/*
+    git

 COPY requirements.txt ./
 COPY src/ ./src/
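The base image switches from a Debian-based Python image, whose removed lines (partially truncated in this rendering) installed build tools, curl, and git via apt and then cleaned /var/lib/apt/lists, to python:3.12-alpine with apk add --no-cache. The hunk ends at the COPY lines, so the install and start steps are not visible here; a typical continuation for a Streamlit Space (an assumption, not part of this diff) would be:

    # Hypothetical continuation -- not shown in this commit
    RUN pip install --no-cache-dir -r requirements.txt
    EXPOSE 8501
    CMD ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]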
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
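requirements.txt is rendered as a binary change, so its contents are not visible in this diff. Based on the imports in src/rag.py and src/streamlit_app.py, a plausible dependency list (assumed, unpinned) would be:

    streamlit
    python-dotenv
    pymongo
    sentence-transformers
    langchain-core
    langchain-google-genai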
src/rag.py
ADDED
@@ -0,0 +1,280 @@
import os
from dotenv import load_dotenv
import langchain_google_genai as genai
import streamlit as st
from sentence_transformers import SentenceTransformer
import pymongo
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import CrossEncoder
from typing import List, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableMap
import time

def safe_log_info(message):
    print(f"INFO: {message}")

def safe_log_warning(message):
    print(f"WARNING: {message}")

def safe_log_error(message, exc_info=False):
    print(f"ERROR: {message}")
    if exc_info:
        import traceback
        traceback.print_exc()

load_dotenv()
google_api_key = os.environ.get("GOOGLE_API_KEY")
mongo_uri = os.environ.get("MONGODB_URI")

@st.cache_resource
def load_generative_model():
    llm = ChatGoogleGenerativeAI(
        model='models/gemini-2.0-flash',
        temperature=0.2,
        max_tokens=None,
        timeout=180,
        max_retries=2,
        convert_system_message_to_human=True,
        api_key=google_api_key
    )
    return llm

@st.cache_resource
def load_embedding_model():
    embedding_model = SentenceTransformer("namdp-ptit/Vidense")
    return embedding_model

@st.cache_resource
def load_mongo_collection():
    client = pymongo.MongoClient(mongo_uri)
    db = client['vietnamese-llms']
    collection = db['vietnamese-llms-data']
    return collection

@st.cache_resource
def load_reranker():
    reranker = CrossEncoder("namdp-ptit/ViRanker")
    return reranker

def get_embedding(text: str) -> list[float]:
    embedding_model = load_embedding_model()
    embedding = embedding_model.encode(text).tolist()
    return embedding

def find_similar_documents_hybrid_search(
    query_vector: list[float],
    search_query: str,
    limit: int = 10,
    candidates: int = 20,
    vector_search_index: str = "embedding_search",
    atlas_search_index: str = "header_text"
) -> list[dict]:
    """
    Hybrid search combining vector and text search with parallel execution.
    """
    all_results = []
    collection = load_mongo_collection()

    def perform_vector_search():
        """Perform the vector search leg."""
        try:
            vector_pipeline = [
                {
                    "$vectorSearch": {
                        "index": vector_search_index,
                        "path": "embedding",
                        "queryVector": query_vector,
                        "limit": limit,
                        "numCandidates": candidates
                    }
                },
                {
                    "$project": {
                        '_id': 1,
                        'header': 1,
                        'content': 1,
                        "vector_score": {"$meta": "vectorSearchScore"}
                    }
                }
            ]

            vector_results = list(collection.aggregate(vector_pipeline))
            safe_log_info(f"Vector search returned {len(vector_results)} results")
            for doc in vector_results:
                doc['search_type'] = 'vector'
                doc['combined_score'] = doc.get('vector_score', 0) * 0.6  # Weight the vector score
            return vector_results
        except Exception as e:
            safe_log_warning(f"Vector search failed: {e}")
            return []

    def perform_text_search():
        """Perform the full-text search leg."""
        if not search_query or not search_query.strip():
            return []

        try:
            text_pipeline = [
                {
                    "$search": {
                        "index": atlas_search_index,
                        "compound": {
                            "must": [
                                {
                                    "text": {
                                        "query": search_query,
                                        "path": ["header", "content"]
                                    }
                                }
                            ]
                        }
                    }
                },
                {
                    "$project": {
                        '_id': 1,
                        'header': 1,
                        'content': 1,
                        "text_score": {"$meta": "searchScore"}
                    }
                }
            ]

            text_results = list(collection.aggregate(text_pipeline))
            safe_log_info(f"Text search returned {len(text_results)} results")
            for doc in text_results:
                doc['search_type'] = 'text'
                doc['combined_score'] = doc.get('text_score', 0) * 0.4  # Weight the text score
            return text_results
        except Exception as e:
            safe_log_warning(f"Text search failed: {e}")
            return []

    try:
        # Run both searches in parallel
        start_time = time.time()
        with ThreadPoolExecutor(max_workers=2) as executor:
            vector_future = executor.submit(perform_vector_search)
            text_future = executor.submit(perform_text_search)

            # Collect results as they complete
            for future in as_completed([vector_future, text_future]):
                try:
                    results = future.result()
                    all_results.extend(results)
                except Exception as e:
                    safe_log_error(f"Error in parallel search: {e}")

        search_time = time.time() - start_time
        safe_log_info(f"Parallel search completed in {search_time:.3f}s")

        # Merge and deduplicate results
        seen_ids = set()
        merged_results = []

        for doc in all_results:
            doc_id = str(doc['_id'])
            if doc_id not in seen_ids:
                seen_ids.add(doc_id)
                # Clean up the document for the final result
                final_doc = {
                    '_id': doc['_id'],
                    'content': doc.get('content', ''),
                    # 'uploader_username': doc.get('uploader_username', ''),  # Removed
                    'header': doc.get('header', ''),
                    'score': doc.get('combined_score', 0)
                }
                merged_results.append(final_doc)
            else:
                # If the document was found by both searches, boost its score
                for existing_doc in merged_results:
                    if str(existing_doc['_id']) == doc_id:
                        existing_doc['score'] += doc.get('combined_score', 0) * 0.5
                        break

        # Sort by combined score
        merged_results.sort(key=lambda x: x.get('score', 0), reverse=True)

        # Return top results
        final_results = merged_results[:limit]
        safe_log_info(f"Hybrid search final results: {len(final_results)} documents")

        return final_results

    except Exception as e:
        safe_log_error(f"Error in hybrid search: {e}", exc_info=True)
        return []

def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reranks a list of documents based on their relevance to the query using a reranker model.

    Args:
        query: The original search query.
        documents: A list of dictionaries, where each dictionary represents a document
                   and contains a 'content' key with the document's text.

    Returns:
        A list of dictionaries representing the reranked documents, sorted by relevance score.
    """
    if not documents:
        return []
    reranker_model = load_reranker()
    # Prepare (query, passage) pairs for the reranker model
    sentence_pairs = [[query, doc.get('content', '')] for doc in documents]

    # Get reranking scores
    rerank_scores = reranker_model.predict(sentence_pairs)

    # Add reranking scores to the documents
    for i, doc in enumerate(documents):
        doc['rerank_score'] = float(rerank_scores[i])  # Convert to float for potential serialization

    # Sort documents by reranking score in descending order
    reranked_documents = sorted(documents, key=lambda x: x.get('rerank_score', -1), reverse=True)

    return reranked_documents

def format_docs(docs):
    return "\n\n".join([doc.get('header', '') + doc.get('content', '') for doc in docs if isinstance(doc, dict) and 'content' in doc and 'header' in doc])

def get_answer_with_rag(query: str) -> str:
    revised_template = ChatPromptTemplate.from_messages([
        ('system', """bạn là một trợ lý AI thân thiện, được thiết kế để giúp khám phá mọi điều về Học viện Bưu chính Viễn thông (PTIT).
Bạn sẽ sử dụng thông tin được cung cấp để trả lời các câu hỏi của người dùng một cách chi tiết và dễ hiểu nhất.
Hãy nhớ rằng, bạn chỉ có thể trả lời dựa trên thông tin bạn cung cấp. Nếu câu hỏi nằm ngoài phạm vi thông tin đó, bạn sẽ cho người dùng biết."""),
        ('human', "Thông tin tham khảo:\n```\n{context}\n```\n\nCâu hỏi của tôi:\n{question}")
    ])
    llm = load_generative_model()
    query_embedding = get_embedding(query)

    context_docs = find_similar_documents_hybrid_search(
        query_vector=query_embedding,
        search_query=query,
        limit=10,
        candidates=20,
        vector_search_index="embedding_search",
        atlas_search_index="header_text"
    )

    reranked_docs = rerank_documents(query, context_docs)
    top_n_docs = reranked_docs[:10]
    context = format_docs(top_n_docs)

    # Context and question are already computed above, so the prompt is filled
    # directly from the dict passed to invoke().
    chain = revised_template | llm | StrOutputParser()
    response = chain.invoke({
        "context": context,
        "question": query})
    return response
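Taken together, rag.py implements a retrieve-rerank-generate pipeline: the query is embedded with the Vidense sentence-transformer, candidate passages are fetched from MongoDB Atlas by running $vectorSearch and full-text $search in parallel and merging their weighted scores, the ViRanker cross-encoder reorders the candidates, and the top passages are packed into a Gemini prompt. A minimal smoke test of the module, run from the src/ directory (assuming a .env file with GOOGLE_API_KEY and MONGODB_URI, a populated vietnamese-llms-data collection, and Atlas indexes named embedding_search and header_text as referenced in the code), could look like this:

    # smoke_test.py -- hypothetical helper, not part of this commit
    from rag import get_answer_with_rag

    if __name__ == "__main__":
        # Any Vietnamese question about PTIT; the answer is grounded in the
        # documents returned by the hybrid search and reranking steps.
        question = "Học viện Bưu chính Viễn thông có những ngành đào tạo nào?"
        # ("Which degree programmes does the Academy offer?")
        print(get_answer_with_rag(question))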
src/streamlit_app.py
CHANGED
@@ -1,40 +1,64 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
+from rag import get_answer_with_rag, load_generative_model, load_reranker, load_embedding_model, load_mongo_collection

-[old lines 6-40 are also removed; their content is truncated in this rendering]
+load_generative_model()
+load_reranker()
+load_embedding_model()
+load_mongo_collection()
+st.set_page_config(
+    page_title="PTIT RAG Chatbot",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# --- MAIN INTERFACE ---
+st.title("🤖 PTIT RAG Chatbot")
+st.caption("Trợ lý ảo thông minh về Học viện Bưu chính Viễn thông")
+
+# Initialize session state to store the chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = [
+        {"role": "assistant", "content": "Xin chào! Tôi có thể giúp gì cho bạn về các thông tin tại PTIT?"}
+    ]
+
+# --- SIDEBAR ---
+with st.sidebar:
+    st.header("Tùy chọn")
+    if st.button("🗑️ Xóa cuộc trò chuyện", use_container_width=True):
+        st.session_state.messages = [
+            {"role": "assistant", "content": "Cuộc trò chuyện đã được xóa. Hãy bắt đầu lại nhé!"}
+        ]
+        st.rerun()
+
+    st.markdown("---")
+    st.markdown("### Về ứng dụng")
+    st.info("Ứng dụng này sử dụng RAG để trả lời câu hỏi dựa trên tài liệu về PTIT.")
+
+# Display the chat history
+for message in st.session_state.messages:
+    avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
+    with st.chat_message(message["role"], avatar=avatar):
+        st.write(message["content"])
+
+def submit_question(question: str):
+    st.session_state.messages.append({"role": "user", "content": question})
+    st.rerun()
+
+# User input area
+if prompt := st.chat_input("Nhập câu hỏi của bạn..."):
+    # Add the user's message to session state and display it immediately
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user", avatar="🧑‍💻"):
+        st.write(prompt)
+
+    with st.chat_message("assistant", avatar="🤖"):
+        with st.spinner("🤖 Tôi đang suy nghĩ, bạn chờ chút nhé..."):
+            try:
+                response = get_answer_with_rag(prompt)
+                st.markdown(response)
+                st.session_state.messages.append({"role": "assistant", "content": response})
+            except Exception as e:
+                error_message = "Rất tiếc, đã có lỗi xảy ra. Vui lòng thử lại sau!"
+                st.error(error_message)
+                st.session_state.messages.append({"role": "assistant", "content": error_message})