hoangchihien3011 committed
Commit a098bc7 · 1 Parent(s): 59959c6

Initialize app
Files changed (5):
  1. .gitignore +2 -0
  2. Dockerfile +4 -6
  3. requirements.txt +0 -0
  4. src/rag.py +280 -0
  5. src/streamlit_app.py +62 -38
.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ .env
Dockerfile CHANGED
@@ -1,13 +1,11 @@
- FROM python:3.9-slim
+ FROM python:3.12-alpine

  WORKDIR /app

- RUN apt-get update && apt-get install -y \
-     build-essential \
+ RUN apk add --no-cache \
+     build-base \
      curl \
-     software-properties-common \
-     git \
-     && rm -rf /var/lib/apt/lists/*
+     git

  COPY requirements.txt ./
  COPY src/ ./src/
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
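(requirements.txt is committed with a binary diff, so its contents are not visible here. Judging from the imports in src/rag.py and src/streamlit_app.py, it presumably pins at least streamlit, pymongo, python-dotenv, sentence-transformers, langchain-core, and langchain-google-genai; treat that list as an inference from the code, not from the file itself.)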
 
src/rag.py ADDED
@@ -0,0 +1,280 @@
+ import os
+ import time
+ import pymongo
+ import streamlit as st
+ from dotenv import load_dotenv
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from typing import List, Dict, Any
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+
+ def safe_log_info(message):
+     print(f"INFO: {message}")
+
+ def safe_log_warning(message):
+     print(f"WARNING: {message}")
+
+ def safe_log_error(message, exc_info=False):
+     print(f"ERROR: {message}")
+     if exc_info:
+         import traceback
+         traceback.print_exc()
+
+ load_dotenv()
+ google_api_key = os.environ.get("GOOGLE_API_KEY")
+ mongo_uri = os.environ.get("MONGODB_URI")
+
+ @st.cache_resource
+ def load_generative_model():
+     llm = ChatGoogleGenerativeAI(
+         model='models/gemini-2.0-flash',
+         temperature=0.2,
+         max_tokens=None,
+         timeout=180,
+         max_retries=2,
+         convert_system_message_to_human=True,
+         api_key=google_api_key
+     )
+     return llm
+
+ @st.cache_resource
+ def load_embedding_model():
+     embedding_model = SentenceTransformer("namdp-ptit/Vidense")
+     return embedding_model
+
+ @st.cache_resource
+ def load_mongo_collection():
+     client = pymongo.MongoClient(mongo_uri)
+     db = client['vietnamese-llms']
+     collection = db['vietnamese-llms-data']
+     return collection
+
+ @st.cache_resource
+ def load_reranker():
+     reranker = CrossEncoder("namdp-ptit/ViRanker")
+     return reranker
+
+ def get_embedding(text: str) -> list[float]:
+     embedding_model = load_embedding_model()
+     embedding = embedding_model.encode(text).tolist()
+     return embedding
+
+ def find_similar_documents_hybrid_search(
+     query_vector: list[float],
+     search_query: str,
+     limit: int = 10,
+     candidates: int = 20,
+     vector_search_index: str = "embedding_search",
+     atlas_search_index: str = "header_text"
+ ) -> list[dict]:
+     """
+     Hybrid search combining vector and text search with parallel execution.
+     """
+     all_results = []
+     collection = load_mongo_collection()
+
+     def perform_vector_search():
+         """Perform the vector-search leg."""
+         try:
+             vector_pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": vector_search_index,
+                         "path": "embedding",
+                         "queryVector": query_vector,
+                         "limit": limit,
+                         "numCandidates": candidates
+                     }
+                 },
+                 {
+                     "$project": {
+                         '_id': 1,
+                         'header': 1,
+                         'content': 1,
+                         "vector_score": {"$meta": "vectorSearchScore"}
+                     }
+                 }
+             ]
+
+             vector_results = list(collection.aggregate(vector_pipeline))
+             safe_log_info(f"Vector search returned {len(vector_results)} results")
+             for doc in vector_results:
+                 doc['search_type'] = 'vector'
+                 doc['combined_score'] = doc.get('vector_score', 0) * 0.6  # Weight vector score
+             return vector_results
+         except Exception as e:
+             safe_log_warning(f"Vector search failed: {e}")
+             return []
+
+     def perform_text_search():
+         """Perform the full-text-search leg."""
+         if not search_query or not search_query.strip():
+             return []
+
+         try:
+             text_pipeline = [
+                 {
+                     "$search": {
+                         "index": atlas_search_index,
+                         "compound": {
+                             "must": [
+                                 {
+                                     "text": {
+                                         "query": search_query,
+                                         "path": ["header", "content"]
+                                     }
+                                 }
+                             ]
+                         }
+                     }
+                 },
+                 {
+                     "$project": {
+                         '_id': 1,
+                         'header': 1,
+                         'content': 1,
+                         "text_score": {"$meta": "searchScore"}
+                     }
+                 }
+             ]
+
+             text_results = list(collection.aggregate(text_pipeline))
+             safe_log_info(f"Text search returned {len(text_results)} results")
+             for doc in text_results:
+                 doc['search_type'] = 'text'
+                 doc['combined_score'] = doc.get('text_score', 0) * 0.4  # Weight text score
+             return text_results
+         except Exception as e:
+             safe_log_warning(f"Text search failed: {e}")
+             return []
+
+     try:
+         # Run both searches in parallel
+         start_time = time.time()
+         with ThreadPoolExecutor(max_workers=2) as executor:
+             vector_future = executor.submit(perform_vector_search)
+             text_future = executor.submit(perform_text_search)
+
+             # Collect results as they complete
+             for future in as_completed([vector_future, text_future]):
+                 try:
+                     all_results.extend(future.result())
+                 except Exception as e:
+                     safe_log_error(f"Error in parallel search: {e}")
+
+         search_time = time.time() - start_time
+         safe_log_info(f"Parallel search completed in {search_time:.3f}s")
+
+         # Merge and deduplicate results
+         seen_ids = set()
+         merged_results = []
+
+         for doc in all_results:
+             doc_id = str(doc['_id'])
+             if doc_id not in seen_ids:
+                 seen_ids.add(doc_id)
+                 # Clean up the document for the final result
+                 final_doc = {
+                     '_id': doc['_id'],
+                     'content': doc.get('content', ''),
+                     'header': doc.get('header', ''),
+                     'score': doc.get('combined_score', 0)
+                 }
+                 merged_results.append(final_doc)
+             else:
+                 # Document found by both search legs: boost its score
+                 for existing_doc in merged_results:
+                     if str(existing_doc['_id']) == doc_id:
+                         existing_doc['score'] += doc.get('combined_score', 0) * 0.5
+                         break
+
+         # Sort by combined score
+         merged_results.sort(key=lambda x: x.get('score', 0), reverse=True)
+
+         # Return top results
+         final_results = merged_results[:limit]
+         safe_log_info(f"Hybrid search final results: {len(final_results)} documents")
+
+         return final_results
+
+     except Exception as e:
+         safe_log_error(f"Error in hybrid search: {e}", exc_info=True)
+         return []
+
+ def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Rerank a list of documents by relevance to the query using a cross-encoder model.
+
+     Args:
+         query: The original search query.
+         documents: A list of dictionaries, each representing a document with a
+             'content' key holding the document's text.
+
+     Returns:
+         The documents sorted by reranking score, most relevant first.
+     """
+     if not documents:
+         return []
+     reranker_model = load_reranker()
+
+     # Prepare (query, passage) pairs for the reranker model
+     sentence_pairs = [[query, doc.get('content', '')] for doc in documents]
+
+     # Get reranking scores
+     rerank_scores = reranker_model.predict(sentence_pairs)
+
+     # Attach reranking scores to the documents
+     for i, doc in enumerate(documents):
+         doc['rerank_score'] = float(rerank_scores[i])  # Convert to float for serialization
+
+     # Sort documents by reranking score in descending order
+     reranked_documents = sorted(documents, key=lambda x: x.get('rerank_score', -1), reverse=True)
+
+     return reranked_documents
+
+ def format_docs(docs):
+     return "\n\n".join(
+         doc.get('header', '') + "\n" + doc.get('content', '')
+         for doc in docs
+         if isinstance(doc, dict) and 'content' in doc and 'header' in doc
+     )
+
+ def get_answer_with_rag(query: str) -> str:
+     revised_template = ChatPromptTemplate.from_messages([
+         ('system', """Bạn là một trợ lý AI thân thiện, được thiết kế để giúp khám phá mọi điều về Học viện Bưu chính Viễn thông (PTIT).
+ Bạn sẽ sử dụng thông tin được cung cấp để trả lời các câu hỏi của người dùng một cách chi tiết và dễ hiểu nhất.
+ Hãy nhớ rằng, bạn chỉ có thể trả lời dựa trên thông tin được cung cấp. Nếu câu hỏi nằm ngoài phạm vi thông tin đó, bạn sẽ cho người dùng biết."""),
+         ('human', "Thông tin tham khảo:\n```\n{context}\n```\n\nCâu hỏi của tôi:\n{question}")
+     ])
+     llm = load_generative_model()
+     query_embedding = get_embedding(query)
+
+     context_docs = find_similar_documents_hybrid_search(
+         query_vector=query_embedding,
+         search_query=query,
+         limit=10,
+         candidates=20,
+         vector_search_index="embedding_search",
+         atlas_search_index="header_text"
+     )
+
+     reranked_docs = rerank_documents(query, context_docs)
+     top_n_docs = reranked_docs[:10]
+     context = format_docs(top_n_docs)
+
+     # The prompt is filled directly from the invoke() dict
+     chain = revised_template | llm | StrOutputParser()
+     response = chain.invoke({
+         "context": context,
+         "question": query
+     })
+     return response
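One detail of find_similar_documents_hybrid_search that is easy to misread in a diff is the score merge: each leg pre-weights its own score (vector score times 0.6, text score times 0.4), and a document surfaced by both legs keeps its first weighted score plus half of the second. Below is a minimal self-contained sketch of just that arithmetic; merge_hybrid and all document ids and scores are illustrative, not part of the commit.

import operator

def merge_hybrid(results: list[dict], limit: int = 10) -> list[dict]:
    # Mirrors the dedup-and-boost merge in find_similar_documents_hybrid_search
    seen: dict[str, dict] = {}
    merged: list[dict] = []
    for doc in results:
        if doc["_id"] not in seen:
            entry = {"_id": doc["_id"], "score": doc["combined_score"]}
            seen[doc["_id"]] = entry
            merged.append(entry)
        else:
            # Found by the other search leg too: add half of its weighted score
            seen[doc["_id"]]["score"] += doc["combined_score"] * 0.5
    merged.sort(key=operator.itemgetter("score"), reverse=True)
    return merged[:limit]

hits = [
    {"_id": "doc-a", "combined_score": 0.9 * 0.6},  # vector leg, weight 0.6
    {"_id": "doc-b", "combined_score": 0.5 * 0.4},  # text leg, weight 0.4
    {"_id": "doc-a", "combined_score": 0.8 * 0.4},  # text leg hit on the same doc
]
# doc-a: 0.54 + 0.32 * 0.5 = 0.70; doc-b: 0.20
print(merge_hybrid(hits))

Note that the reported score of a double-hit document is not a plain weighted sum; the second leg only contributes at half strength, which rewards agreement between the two searches without letting one leg dominate.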
src/streamlit_app.py CHANGED
@@ -1,40 +1,64 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
+ from rag import get_answer_with_rag, load_generative_model, load_reranker, load_embedding_model, load_mongo_collection

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ # set_page_config must be the first Streamlit call on the page
+ st.set_page_config(
+     page_title="PTIT RAG Chatbot",
+     page_icon="🤖",
+     layout="wide"
+ )
+
+ # Warm up the cached resources once at startup
+ load_generative_model()
+ load_reranker()
+ load_embedding_model()
+ load_mongo_collection()
+
+ # --- MAIN UI ---
+ st.title("🤖 PTIT RAG Chatbot")
+ st.caption("Trợ lý ảo thông minh về Học viện Bưu chính Viễn thông")
+
+ # Initialize session state to store the chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = [
+         {"role": "assistant", "content": "Xin chào! Tôi có thể giúp gì cho bạn về các thông tin tại PTIT?"}
+     ]
+
+ # --- SIDEBAR ---
+ with st.sidebar:
+     st.header("Tùy chọn")
+     if st.button("🗑️ Xóa cuộc trò chuyện", use_container_width=True):
+         st.session_state.messages = [
+             {"role": "assistant", "content": "Cuộc trò chuyện đã được xóa. Hãy bắt đầu lại nhé!"}
+         ]
+         st.rerun()
+
+     st.markdown("---")
+     st.markdown("### Về ứng dụng")
+     st.info("Ứng dụng này sử dụng RAG để trả lời câu hỏi dựa trên tài liệu về PTIT.")
+
+ # Render the chat history
+ for message in st.session_state.messages:
+     avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
+     with st.chat_message(message["role"], avatar=avatar):
+         st.write(message["content"])
+
+ # User input area
+ if prompt := st.chat_input("Nhập câu hỏi của bạn..."):
+     # Append the user's message to session state and render it immediately
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user", avatar="🧑‍💻"):
+         st.write(prompt)
+
+     with st.chat_message("assistant", avatar="🤖"):
+         with st.spinner("🤖 Tôi đang suy nghĩ, bạn chờ chút nhé..."):
+             try:
+                 response = get_answer_with_rag(prompt)
+                 st.markdown(response)
+                 st.session_state.messages.append({"role": "assistant", "content": response})
+             except Exception:
+                 error_message = "Rất tiếc, đã có lỗi xảy ra. Vui lòng thử lại sau!"
+                 st.error(error_message)
+                 st.session_state.messages.append({"role": "assistant", "content": error_message})
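A note on running this commit locally, inferred from the files above rather than stated in them: venv and .env are git-ignored, and src/rag.py reads GOOGLE_API_KEY and MONGODB_URI from the environment via load_dotenv(), so both secrets are expected in a local .env file next to the project. Assuming the dependencies from requirements.txt are installed, the app should start with: streamlit run src/streamlit_app.py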