ngcanh commited on
Commit
1112aa6
·
verified ·
1 Parent(s): 1b8c222

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -142
app.py CHANGED
@@ -1,150 +1,240 @@
1
- __import__('pysqlite3')
2
- import sys
3
- sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- import streamlit as st
6
- from huggingface_hub import InferenceClient
7
- from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, PromptTemplate
8
- from llama_index.vector_stores.chroma import ChromaVectorStore
9
- from llama_index.core import StorageContext
10
- from langchain.embeddings import HuggingFaceEmbeddings
11
- from langchain.text_splitter import CharacterTextSplitter
12
- from langchain.vectorstores import Chroma
13
- import chromadb
14
- from langchain.memory import ConversationBufferMemory
15
- import pandas as pd
16
- from langchain.schema import Document
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Set page config
20
- st.set_page_config(page_title="MBAL Chatbot", page_icon="🛡️", layout="wide")
 
21
 
22
- # Set your Hugging Face token here
 
23
 
24
- HF_TOKEN = st.secrets["HF_TOKEN"]
 
 
25
 
26
- @st.cache_resource
27
- def init_chroma():
28
- persist_directory = "chroma_db"
29
- chroma_client = chromadb.PersistentClient(path=persist_directory)
30
- chroma_collection = chroma_client.get_or_create_collection("my_collection")
31
- return chroma_client, chroma_collection
32
 
33
- @st.cache_resource
34
- def init_vectorstore():
35
- persist_directory = "chroma_db"
36
- embeddings = HuggingFaceEmbeddings()
37
- vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings, collection_name="my_collection")
38
- return vectorstore
39
- @st.cache_resource
40
- def setup_vector():
41
- # Đọc dữ liệu từ file Excel
42
- df = pd.read_excel("chunk_metadata_template.xlsx")
43
- chunks = []
44
-
45
- # Tạo danh sách các Document có metadata
46
- for _, row in df.iterrows():
47
- chunk_with_metadata = Document(
48
- page_content=row['page_content'],
49
- metadata={
50
- 'chunk_id': row['chunk_id'],
51
- 'document_title': row['document_title'],
52
- 'topic': row['topic'],
53
- 'access': row['access']
54
- }
55
- )
56
- chunks.append(chunk_with_metadata)
57
-
58
- # Khởi tạo embedding
59
- embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
60
-
61
- # Khởi tạo hoặc ghi vào vectorstore đã tồn tại
62
- persist_directory = "chroma_db"
63
- collection_name = "my_collection"
64
-
65
- # Tạo vectorstore từ dữ liệu ghi vào Chroma
66
- vectorstore = Chroma.from_documents(
67
- documents=chunks,
68
- embedding=embeddings,
69
- persist_directory=persist_directory,
70
- collection_name=collection_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  )
72
-
73
- # Ghi xuống đĩa để đảm bảo dữ liệu được lưu
74
- vectorstore.persist()
75
-
76
- return vectorstore
77
-
78
- # Initialize components
79
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=HF_TOKEN)
80
- chroma_client, chroma_collection = init_chroma()
81
- init_vectorstore()
82
- vectorstore = setup_vector()
83
-
84
- # Initialize memory buffer
85
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
86
-
87
- def rag_query(query):
88
- # Lấy tài liệu liên quan
89
- retrieved_docs = vectorstore.similarity_search(query, k=5)
90
- context = "\n".join([doc.page_content for doc in retrieved_docs]) if retrieved_docs else ""
91
-
92
- # Lấy tương tác cũ
93
- past_interactions = memory.load_memory_variables({})[memory.memory_key]
94
- context_with_memory = f"{context}\n\nConversation History:\n{past_interactions}"
95
-
96
- # Chuẩn bị prompt
97
- messages = [
98
- {
99
- "role": "user",
100
- "content": f"""You are a consultant advising clients on insurance products from MB Ageas Life in Vietnam. Please respond professionally and accurately, and suggest suitable products by asking a few questions about the customer's needs. All information provided must remain within the scope of MBAL. Invite the customer to register for a more detailed consultation at https://www.mbageas.life/
101
- {context_with_memory}
102
- Question: {query}
103
- Answer:"""
104
- }
105
- ]
106
-
107
- response_content = client.chat_completion(messages=messages, max_tokens=1024, stream=False)
108
- response = response_content.choices[0].message.content.split("Answer:")[-1].strip()
109
- return response
110
-
111
-
112
- def process_feedback(query, response, feedback):
113
- # st.write(f"Feedback received: {'👍' if feedback else '👎'} for query: {query}")
114
- if feedback:
115
- # If thumbs up, store the response in memory buffer
116
- memory.chat_memory.add_ai_message(response)
117
- else:
118
- # If thumbs down, remove the response from memory buffer and regenerate the response
119
- # memory.chat_memory.messages = [msg for msg in memory.chat_memory.messages if msg.get("content") != response]
120
- new_query=f"{query}. Tạo câu trả lời đúng với câu hỏi"
121
- new_response = rag_query(new_query)
122
- st.markdown(new_response)
123
- memory.chat_memory.add_ai_message(new_response)
124
-
125
- # Streamlit interface
126
-
127
- st.title("Chào mừng bạn đã đến với MBAL Chatbot")
128
- st.markdown("***")
129
- st.info('''
130
- Tôi sẽ giải đáp các thắc mắc của bạn liên quan đến các sản phẩm bảo hiểm nhân thọ của MB Ageas Life''')
131
-
132
- col1, col2 = st.columns(2)
133
-
134
- with col1:
135
- chat = st.button("Chat")
136
- if chat:
137
- st.switch_page("pages/chatbot.py")
138
- import streamlit as st
139
-
140
-
141
- Sidebar
142
- with st.sidebar:
143
- st.header("Lựa chọn khác")
144
- if st.button("Xóa lịch sử chat"):
145
- st.session_state.messages = []
146
- memory.clear()
147
- st.rerun()
148
-
149
-
150
- st.markdown("<div style='text-align:center;'></div>", unsafe_allow_html=True)
 
1
+ import streamlit as st #? run app streamlit run file_name.py
2
+ import tempfile
3
+ import os
4
+ import torch
5
+
6
+ from transformers.utils.quantization_config import BitsAndBytesConfig # for compressing model e.g. 16bits -> 4bits
7
+ from transformers import (
8
+ AutoTokenizer, # Tokenize Model
9
+ AutoModelForCausalLM, # LLM Loader - used for loading and using pre-trained models designed for causal language modeling tasks
10
+ pipeline) # pipline to setup llm-task oritented model
11
+ # pipline("text-classification", model='model', device=0)
12
+
13
+ from langchain_huggingface import HuggingFaceEmbeddings # huggingface sentence_transformer embedding models
14
+ from langchain_huggingface.llms import HuggingFacePipeline # like transformer pipeline
15
+
16
+ from langchain.memory import ConversationBufferMemory # Deprecated
17
+ from langchain_community.chat_message_histories import ChatMessageHistory # Deprecated
18
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader # PDF Processing
19
+ from langchain.chains import ConversationalRetrievalChain # Deprecated
20
+ from langchain_experimental.text_splitter import SemanticChunker # module for chunking text
21
+
22
+ from langchain_chroma import Chroma # AI-native vector databases (ai-native mean built for handle large-scale AI workloads efficiently)
23
+ from langchain_text_splitters import RecursiveCharacterTextSplitter # recursively divide text, then merge them together if merge_size < chunk_size
24
+ from langchain_core.runnables import RunnablePassthrough # Use for testing (make 'example' easy to execute and experiment with)
25
+ from langchain_core.output_parsers import StrOutputParser # format LLM's output text into (list, dict or any custom structure we can work with)
26
+ from langchain import hub
27
+ from langchain_core.prompts import PromptTemplate
28
+ import json
29
+
30
+ # Save RAG chain builded from PDF
31
+ if 'rag_chain' not in st.session_state:
32
+ st.session_state.rag_chain = None
33
+
34
+ # Check if models downloaded or not
35
+ if 'models_loaded' not in st.session_state:
36
+ st.session_state.models_loaded = False
37
+
38
+ # save downloaded embeding model
39
+ if 'embeddings' not in st.session_state:
40
+ st.session_state.embeddings = None
41
+
42
+ # Save downloaded LLM
43
+ if 'llm' not in st.session_state:
44
+ st.session_state.llm = None
45
+
46
+ @st.cache_resource # cache model embeddings, avoid model reloading each runtime
47
+ def load_embeddings():
48
+ return HuggingFaceEmbeddings(model_name='bkai-foundation-models/vietnamese-bi-encoder')
49
+
50
+
51
+ # set up config
52
+ nf4_config = BitsAndBytesConfig(
53
+ load_in_4bit=True,
54
+ bnb_4bit_quant_type="nf4",
55
+ bnb_4bit_use_double_quant=True,
56
+ bnb_4bit_compute_dtype=torch.bfloat16
57
+ )
58
+
59
+ #? Read huggingface token in token.txt file. Please paste your huggingface token in token.txt
60
+ @st.cache_resource
61
+ def get_hg_token():
62
+ with open('token.txt', 'r') as f:
63
+ hg_token = f.read()
64
 
65
+ @st.cache_resource
66
+ def load_llm():
67
+ # MODEL_NAME= "lmsys/vicuna-7b-v1.5"
68
+ MODEL_NAME = "google/gemma-2b-it"
69
+
70
+ model = AutoModelForCausalLM.from_pretrained(
71
+ MODEL_NAME,
72
+ quantization_config=nf4_config, # add config
73
+ torch_dtype=torch.bfloat16, # save memory using float16
74
+ # low_cpu_mem_usage=True,
75
+ token=get_hg_token(),
76
+ ).to("cuda")
77
+
78
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
79
+ model_pipeline = pipeline(
80
+ 'text-generation',
81
+ model=model,
82
+ tokenizer=tokenizer,
83
+ max_new_tokens=1024, # output token
84
+ device_map="auto" # auto allocate GPU if available
85
+ )
86
 
87
+ return HuggingFacePipeline(pipeline=model_pipeline)
88
+
89
+ def format_docs(docs):
90
+ return "\n\n".join(doc.page_content for doc in docs)
91
+
92
+ def process_pdf(uploaded_file):
93
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
94
+ tmp_file.write(uploaded_file.getvalue())
95
+ tmp_file_path = tmp_file.name
96
+
97
+ try:
98
+ loader = PyPDFLoader(tmp_file_path)
99
+ documents = loader.load()
100
+ except Exception as e:
101
+ st.error(f"Đọc file thất bại: {e}")
102
+ return None, 0
103
+
104
+ semantic_splitter = SemanticChunker(
105
+ embeddings=st.session_state.embeddings,
106
+ buffer_size=1, # total sentence collected before perform text split
107
+ breakpoint_threshold_type='percentile', # set splitting style: 'percentage' of similarity
108
+ breakpoint_threshold_amount=95, # split text if similarity score > 95%
109
+ min_chunk_size=500,
110
+ add_start_index=True, # assign index for chunk
111
+ )
112
 
113
+ docs = semantic_splitter.split_documents(documents)
114
+ vector_db = Chroma.from_documents(documents=docs,
115
+ embedding=st.session_state.embeddings)
116
 
117
+ retriever = vector_db.as_retriever()
118
+ parser = StrOutputParser()
119
 
120
+ # prompt = PromptTemplate.from_template("""
121
+ # Trả lời ngắn gọn, rõ ràng bằng tiếng việt và chỉ dựa trên thông tin có sẵn bên dưới.
122
+ # Nếu không tìm thấy thông tin, hãy nói rõ là không có dữ liệu liên quan.
123
 
124
+ # Nội dung tài liệu:
125
+ # {context}
 
 
 
 
126
 
127
+ # Câu hỏi:
128
+ # {question}
129
+
130
+ # Trả lời:
131
+ # """)
132
+
133
+
134
+ # prompt = PromptTemplate.from_template("""
135
+ # Dựa vào nội dung sau, hãy:
136
+ # 1. Tóm tắt tối đa 3 ý chính, kèm theo số trang nếu có.
137
+ # 2. Trả lời câu hỏi bằng tiếng Việt ngắn gọn và chính xác.
138
+ # 3. Nếu không có thông tin liên quan, hãy để `"Trả lời"` là `"Không có dữ liệu liên quan"`.
139
+
140
+ # Nội dung tài liệu:
141
+ # {context}
142
+
143
+ # Câu hỏi:
144
+ # {question}
145
+
146
+ # Trả lời:
147
+ # """)
148
+
149
+ prompt = PromptTemplate.from_template("""
150
+ Bạn là trợ lý AI.
151
+
152
+ Dựa vào nội dung sau, hãy:
153
+ 1. Tóm tắt tối đa 3 ý chính, kèm theo số trang nếu có.
154
+ 2. Trả lời câu hỏi bằng tiếng Việt ngắn gọn và chính xác.
155
+ 3. Nếu không thông tin liên quan, hãy để "Answer" là "Không có dữ liệu liên quan".
156
+
157
+
158
+
159
+ Đảm bảo trả kết quả **ở dạng JSON** với cấu trúc sau:
160
+ {{"main_ideas": [
161
+ {{"point": "Ý chính 1", "source": "Trang ..."}},
162
+ {{"point": "Ý chính 2", "source": "Trang ..."}},
163
+ {{"point": "Ý chính 3", "source": "Trang ..."}}
164
+ ],
165
+ "answer": "Câu trả lời ngắn gọn"
166
+ }}
167
+
168
+ Vui lòng chỉ in JSON, không giải thích thêm.
169
+
170
+ Context:
171
+ {context}
172
+
173
+ Question:
174
+ {question}
175
+
176
+ Answer:
177
+
178
+ """) #? dùng {{ }} để langchain không nhận string bên trong {} là Biến
179
+
180
+
181
+ rag_chain = (
182
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
183
+ | prompt
184
+ | st.session_state.llm
185
+ | parser
186
  )
187
+
188
+ os.unlink(tmp_file_path)
189
+ return rag_chain, len(docs)
190
+
191
+ st.set_page_config(page_title="PDF RAG Assistant", layout='wide')
192
+ st.title('PDF RAG Assistant')
193
+
194
+ st.markdown("""
195
+ **Ứng dụng AI giúp bạn hỏi đáp trực tiếp với nội dung tài liệu PDF bằng tiếng Việt**
196
+ **Cách sử dụng đơn giản:**
197
+ 1. **Upload PDF** Chọn file PDF từ máy tính và nhấn "Xử lý PDF"
198
+ 2. **Đặt câu hỏi** Nhập câu hỏi về nội dung tài liệu và nhận câu trả lời ngay lập tức
199
+ """)
200
+
201
+ #? Tải models
202
+ if not st.session_state.models_loaded:
203
+ st.info("Đang tải models...")
204
+ st.session_state.embeddings = load_embeddings()
205
+ st.session_state.llm = load_llm()
206
+ st.session_state.models_loaded = True
207
+ st.success("Models đã sẵn sàng!")
208
+ st.rerun()
209
+
210
+ #? Upload and Process PDF
211
+ uploaded_file = st.file_uploader("Upload file PDF", type="pdf")
212
+ if uploaded_file and st.button("Xử lý PDF"):
213
+ with st.spinner("Đang xử lý..."):
214
+ st.session_state.rag_chain, num_chunks = process_pdf(uploaded_file)
215
+ st.success(f"Hoàn thành! {num_chunks} chunks")
216
+
217
+
218
+ #? Answers UI
219
+ if st.session_state.rag_chain:
220
+ question = st.text_input("Đặt câu hỏi:")
221
+ if question:
222
+ with st.spinner("Đang trả lời..."):
223
+ raw_output = st.session_state.rag_chain.invoke(question)
224
+ try:
225
+ result = json.loads(raw_output)
226
+ st.write("📌 **Nội dung chính:**")
227
+ st.write("raw_output:", raw_output)
228
+ for idea in result["main_ideas"]:
229
+ st.markdown(f"- {idea['point']} (📄 {idea['source']})")
230
+
231
+ st.write("🧠 **Trả lời:**")
232
+ st.markdown(result["answer"])
233
+
234
+ except json.JSONDecodeError:
235
+ st.error("⚠️ Output không đúng JSON")
236
+ st.text(raw_output)
237
+
238
+ # answer = output.split("Answer:")[1].strip() if "Answer:" in output else output.strip()
239
+ # st.write("**Trả lời:**")
240
+ # st.write(answer)