ngcanh committed on
Commit a76ab68 · verified · 1 Parent(s): 0a3b438

Update app.py

Files changed (1)
  1. app.py +107 -126
app.py CHANGED
@@ -1,57 +1,39 @@
- import streamlit as st  #? to run the app: streamlit run file_name.py
- import tempfile
  import os
- import torch
-
- from transformers import (
-     AutoTokenizer,         # tokenizer for the model
-     AutoModelForCausalLM,  # loader for pre-trained causal language models
-     pipeline)              # pipeline to set up the LLM for a specific task
- # pipeline("text-classification", model='model', device=0)
-
- from langchain_huggingface import HuggingFaceEmbeddings  # HuggingFace sentence-transformers embedding models
- from langchain_huggingface.llms import HuggingFacePipeline  # like the transformers pipeline
-
- from langchain.memory import ConversationBufferMemory  # deprecated
- from langchain_community.chat_message_histories import ChatMessageHistory  # deprecated
- from langchain_community.document_loaders import PyPDFLoader, TextLoader  # PDF processing
- from langchain.chains import ConversationalRetrievalChain  # deprecated
- from langchain_experimental.text_splitter import SemanticChunker  # module for chunking text
-
- from langchain_chroma import Chroma  # AI-native vector database (built to handle large-scale AI workloads efficiently)
- from langchain_text_splitters import RecursiveCharacterTextSplitter  # recursively splits text, then merges pieces back while merged_size < chunk_size
- from langchain_core.runnables import RunnablePassthrough  # passes input through unchanged; useful for testing and simple chains
- from langchain_core.output_parsers import StrOutputParser  # formats the LLM's output into a structure we can work with (list, dict, or custom)
- from langchain import hub
- from langchain_core.prompts import PromptTemplate
- import json
- from sentence_transformers import SentenceTransformer
- HF_TOKEN = st.secrets["HF_TOKEN"]
-
- # Save the RAG chain built from the PDF
- if 'rag_chain' not in st.session_state:
-     st.session_state.rag_chain = None
-
- # Check whether the models have been downloaded yet
- if 'models_loaded' not in st.session_state:
-     st.session_state.models_loaded = False
-
- # Save the downloaded embedding model
- if 'embeddings' not in st.session_state:
-     st.session_state.embeddings = None
-
- # Save the downloaded LLM
- if 'llm' not in st.session_state:
-     st.session_state.llm = None
-
- @st.cache_resource  # cache the embedding model to avoid reloading it on every rerun
- def load_embeddings():
-     return SentenceTransformer("bkai-foundation-models/vietnamese-bi-encoder")

  @st.cache_resource
- def load_llm():
-     # MODEL_NAME = "lmsys/vicuna-7b-v1.5"
      MODEL_NAME = "google/gemma-2b-it"

      model = AutoModelForCausalLM.from_pretrained(
@@ -73,95 +55,94 @@ def load_llm():

      return HuggingFacePipeline(pipeline=model_pipeline)
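(The diff view collapses the unchanged middle of load_llm, old lines 58–72, where model_pipeline is built. Purely as a sketch of what typically sits between from_pretrained and this return, assuming a standard transformers text-generation setup — the parameter values are illustrative, not the file's actual ones:

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,  # illustrative cap on generated tokens
    )
)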

- def format_docs(docs):
-     return "\n\n".join(doc.page_content for doc in docs)
-
- def process_pdf(uploaded_file):
-     df = pd.read_excel("chunk_metadata_template.xlsx")
-     docs = []

-     # Build the list of Documents with metadata
      for _, row in df.iterrows():
          chunk_with_metadata = Document(
              page_content=row['page_content'],
              metadata={
                  'chunk_id': row['chunk_id'],
-                 'document_title': row['document_title']
              }
          )
-         docs.append(chunk_with_metadata)
-     vector_db = Chroma.from_documents(documents=docs,
-                                       embedding=st.session_state.embeddings)

-     retriever = vector_db.as_retriever()
-     parser = StrOutputParser()

-     prompt = PromptTemplate.from_template("""
- Bạn là một chuyên viên tư vấn cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
  Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin trước rồi hỏi câu tiếp theo. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Khi có đủ thông tin khách hàng thì mới mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
  {context}
  Câu hỏi: {question}
- Trả lời:
-     """)  #? use {{ }} so LangChain does not treat the text inside {} as a variable
-
-     rag_chain = (
-         {"context": retriever | format_docs, "question": RunnablePassthrough()}
          | prompt
-         | st.session_state.llm
          | parser
      )
-
-     os.unlink(tmp_file_path)
-     return rag_chain, len(docs)
-
- st.set_page_config(page_title="PDF RAG Assistant", layout='wide')
- st.title('PDF RAG Assistant')
-
- st.markdown("""
- **Ứng dụng AI giúp bạn hỏi đáp trực tiếp về thông tin các gói bảo hiểm của MB Ageas Life**
- """)
-
- #? Load the models
- if not st.session_state.models_loaded:
-     st.info("Đang tải model...")
-     st.session_state.embeddings = load_embeddings()
-     st.session_state.llm = load_llm()
-     st.session_state.models_loaded = True
-     st.success("Model đã sẵn sàng!")
-     st.rerun()
-
- # #? Upload and process the PDF
- # uploaded_file = st.file_uploader("Upload file PDF", type="pdf")
- # if uploaded_file and st.button("Xử lý PDF"):
- #     with st.spinner("Đang xử lý..."):
- #         st.session_state.rag_chain, num_chunks = process_pdf(uploaded_file)
- #         st.success(f"Hoàn thành! {num_chunks} chunks")
-
- #? Answers UI
- if st.session_state.rag_chain:
-     question = st.text_input("Đặt câu hỏi:")
-     if question:
-         with st.spinner("Đang trả lời..."):
-             raw_output = st.session_state.rag_chain.invoke(question)
-             try:
-                 result = json.loads(raw_output)
-                 st.write("📌 **Nội dung chính:**")
-                 st.write("raw_output:", raw_output)
-                 for idea in result["main_ideas"]:
-                     st.markdown(f"- {idea['point']} (📄 {idea['source']})")
-
-                 st.write("🧠 **Trả lời:**")
-                 st.markdown(result["answer"])
-
-             except json.JSONDecodeError:
-                 st.error("⚠️ Output không đúng JSON")
-                 st.text(raw_output)
-
-         # answer = output.split("Answer:")[1].strip() if "Answer:" in output else output.strip()
-         # st.write("**Trả lời:**")
-         # st.write(answer)
+ import streamlit as st
+ from langchain.llms import HuggingFacePipeline
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.prompts.prompt import PromptTemplate
+ from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from langchain.schema import Document
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.vectorstores import Chroma
+ from transformers import TextStreamer
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_community.llms import HuggingFaceHub
+ from langchain_core.runnables import RunnablePassthrough  # needed by the LCEL chain below
+ from langchain_core.output_parsers import StrOutputParser  # needed by the LCEL chain below
  import os
+ import pandas as pd
+ from langchain.vectorstores import FAISS
+ import subprocess
+

+ # Model configuration
+ MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ model_name = "google/gemma-2-2b"
+ TOKEN = os.getenv('HF_TOKEN')
+ subprocess.run(["huggingface-cli", "login", "--token", TOKEN, "--add-to-git-credential"])
+ ######
+ # set this key as an environment variable
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]

+
+ # Load the LLM and wrap it in a pipeline
  @st.cache_resource
+ def load_model():
+     # MODEL_NAME = "lmsys/vicuna-7b-v1.5"
      MODEL_NAME = "google/gemma-2b-it"

      model = AutoModelForCausalLM.from_pretrained(

      return HuggingFacePipeline(pipeline=model_pipeline)

+ # Initialize embeddings
+ @st.cache_resource
+ def load_embeddings():
+     # the repo id is "bkai-foundation-models/vietnamese-bi-encoder" (as in the previous version);
+     # prefixing it with "sentence-transformers/" would make it an invalid Hub path
+     embeddings = HuggingFaceEmbeddings(model_name='bkai-foundation-models/vietnamese-bi-encoder')
+     # embeddings = OpenAIEmbeddings()
+     return embeddings
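(For a quick sanity check of the embedding model, a minimal sketch — the query text is illustrative:

    emb = load_embeddings()
    vec = emb.embed_query("quyền lợi bảo hiểm")  # encode one query string
    print(len(vec))                              # dimensionality of the sentence vector
)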
 
+ # Chroma vector store
+ @st.cache_resource
+ def setup_vector():
+     chunks = []
+     df = pd.read_excel(r"chunk_metadata_template.xlsx")
      for _, row in df.iterrows():
          chunk_with_metadata = Document(
              page_content=row['page_content'],
              metadata={
                  'chunk_id': row['chunk_id'],
+                 'document_title': row['document_title'],
              }
          )
+         chunks.append(chunk_with_metadata)
+     embeddings = load_embeddings()  # reuse the cached embedding model instead of re-instantiating it
+     return Chroma.from_documents(chunks, embedding=embeddings)
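(The spreadsheet is expected to provide page_content, chunk_id, and document_title columns, per the loop above. A minimal retrieval smoke test, with an illustrative query:

    vector = setup_vector()
    hits = vector.similarity_search("quyền lợi bảo hiểm MBAL", k=5)  # top-5 closest chunks
    for doc in hits:
        print(doc.metadata['chunk_id'], doc.page_content[:80])
)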

+ # Set up the chain
+ def setup_conversation_chain():
+     llm = load_model()
+     vector = setup_vector()
+     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)  # only needed by the commented-out chain below

+     template = """Bạn là một chuyên viên tư vấn cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
  Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin trước rồi hỏi câu tiếp theo. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Khi có đủ thông tin khách hàng thì mới mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
  {context}
  Câu hỏi: {question}
+ Trả lời:"""
+
+     # PROMPT = ChatPromptTemplate.from_template(template=template)
+     # chain = ConversationalRetrievalChain.from_llm(
+     #     llm=llm,
+     #     retriever=vector.as_retriever(search_kwargs={'k': 5}),
+     #     memory=memory,
+     #     combine_docs_chain_kwargs={"prompt": PROMPT},
+     #     # condense_question_prompt=CUSTOM_QUESTION_PROMPT
+     # )
+
+     # pieces the LCEL chain needs (format_docs carries over from the previous version)
+     def format_docs(docs):
+         return "\n\n".join(doc.page_content for doc in docs)
+     prompt = PromptTemplate.from_template(template)
+     parser = StrOutputParser()
+
+     chain = (
+         {"context": vector.as_retriever(search_kwargs={'k': 5}) | format_docs, "question": RunnablePassthrough()}
          | prompt
+         | llm
          | parser
      )

+     return chain
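(Because the chain ends in StrOutputParser, it takes a question string and returns a plain string. A minimal sketch of calling it directly — the question is illustrative:

    chain = setup_conversation_chain()
    answer = chain.invoke("Gói bảo hiểm nào phù hợp cho gia đình trẻ?")
    print(answer)
)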
111
+
112
+ # Streamlit
113
+ def main():
114
+ st.title("🛡️ MBAL Chatbot 🛡️")
115
+
116
+ # Inicializar la cadena de conversación
117
+ if 'conversation_chain' not in st.session_state:
118
+ st.session_state.conversation_chain = setup_conversation_chain()
119
+
120
+ # Mostrar mensajes del chat
121
+ if 'messages' not in st.session_state:
122
+ st.session_state.messages = []
123
+
124
+ for message in st.session_state.messages:
125
+ with st.chat_message(message["role"]):
126
+ st.markdown(message["content"])
127
+
128
+ # Campo de entrada para el usuario
129
+ if prompt := st.chat_input("Bạn cần tư vấn về điều gì? Hãy chia sẻ nhu cầu và thông tin của bạn nhé!"):
130
+ st.session_state.messages.append({"role": "user", "content": prompt})
131
+ with st.chat_message("user"):
132
+ st.markdown(prompt)
133
+
134
+ with st.chat_message("assistant"):
135
+ message_placeholder = st.empty()
136
+ full_response = ""
137
+
138
+ # Generar respuesta
139
+ response = st.session_state.conversation_chain({"question": prompt, "chat_history": []})
140
+ full_response = response['answer']
141
+ # full_response = response.get("answer", "No response generated.")
142
+
143
+ message_placeholder.markdown(full_response)
144
+
145
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
146
+
147
+ # if __name__ == "__main__":
148
+ main()
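(As the previous version's own header comment noted, the app is launched with streamlit run app.py; both versions assume HF_TOKEN is available in the Space's Streamlit secrets.)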