ngcanh commited on
Commit
d8aac59
·
verified ·
1 Parent(s): 7c4be71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -139
app.py CHANGED
@@ -1,150 +1,118 @@
1
- __import__('pysqlite3')
2
- import sys
3
- sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
-
5
- # DATABASES = {
6
- # 'default': {
7
- # 'ENGINE': 'django.db.backends.sqlite3',
8
- # 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
9
- # }
10
- # }
11
  import streamlit as st
12
- from huggingface_hub import InferenceClient
13
- from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, PromptTemplate
14
- from llama_index.vector_stores.chroma import ChromaVectorStore
15
- from llama_index.core import StorageContext
16
- from langchain.embeddings import HuggingFaceEmbeddings
17
- from langchain.text_splitter import CharacterTextSplitter
18
- from langchain.vectorstores import Chroma
19
- import chromadb
20
- from langchain.memory import ConversationBufferMemory
21
- import pandas as pd
22
- from langchain.schema import Document
23
-
24
-
25
- # Set page config
26
- st.set_page_config(page_title="MBAL Chatbot", page_icon="🛡️", layout="wide")
27
-
28
- # Set your Hugging Face token here
29
-
30
- HF_TOKEN = st.secrets["HF_TOKEN"]
31
-
32
- @st.cache_resource
33
- def init_chroma():
34
- persist_directory = "chroma_db"
35
- chroma_client = chromadb.PersistentClient(path=persist_directory)
36
- chroma_collection = chroma_client.get_or_create_collection("my_collection")
37
- return chroma_client, chroma_collection
38
-
39
- @st.cache_resource
40
- def init_vectorstore():
41
- persist_directory = "chroma_db"
42
- embeddings = HuggingFaceEmbeddings()
43
- vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings, collection_name="my_collection")
44
- return vectorstore
45
- @st.cache_resource
46
- def setup_vector():
47
- # Đọc dữ liệu từ file Excel
48
- df = pd.read_excel("chunk_metadata_template.xlsx")
49
- chunks = []
50
-
51
- # Tạo danh sách các Document có metadata
52
- for _, row in df.iterrows():
53
- chunk_with_metadata = Document(
54
- page_content=row['page_content'],
55
- metadata={
56
- 'chunk_id': row['chunk_id'],
57
- 'document_title': row['document_title'],
58
- 'topic': row['topic'],
59
- 'access': row['access']
60
- }
61
- )
62
- chunks.append(chunk_with_metadata)
63
-
64
- # Khởi tạo embedding
65
- embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
66
-
67
- # Khởi tạo hoặc ghi vào vectorstore đã tồn tại
68
- persist_directory = "chroma_db"
69
- collection_name = "my_collection"
70
-
71
- # Tạo vectorstore từ dữ liệu và ghi vào Chroma
72
- vectorstore = Chroma.from_documents(
73
- documents=chunks,
74
- embedding=embeddings,
75
- persist_directory=persist_directory,
76
- collection_name=collection_name
77
- )
78
-
79
- # Ghi xuống đĩa để đảm bảo dữ liệu được lưu
80
- vectorstore.persist()
81
-
82
- return vectorstore
83
-
84
- # Initialize components
85
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token=HF_TOKEN)
86
- chroma_client, chroma_collection = init_chroma()
87
- init_vectorstore()
88
- vectorstore = setup_vector()
89
-
90
- # Initialize memory buffer
91
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
92
-
93
- def rag_query(query):
94
- # Lấy tài liệu liên quan
95
- retrieved_docs = vectorstore.similarity_search(query, k=5)
96
- context = "\n".join([doc.page_content for doc in retrieved_docs]) if retrieved_docs else ""
97
-
98
- # Lấy tương tác cũ
99
- past_interactions = memory.load_memory_variables({})[memory.memory_key]
100
- context_with_memory = f"{context}\n\nConversation History:\n{past_interactions}"
101
-
102
- # Chuẩn bị prompt
103
- messages = [
104
- {
105
- "role": "user",
106
- "content": f"""Bạn là một chuyên gia tư vấn, hỗ trợ khách hàng lựa chọn các sản phẩm bảo hiểm của MB Ageas Life tại Việt Nam. Vui lòng phản hồi một cách chuyên nghiệp và chính xác, đồng thời đề xuất các sản phẩm phù hợp bằng cách đặt một vài câu hỏi để tìm hiểu nhu cầu của khách hàng. Mọi thông tin cung cấp phải nằm trong phạm vi các sản phẩm của MB Ageas Life. Hãy mời khách hàng đăng ký để được tư vấn chi tiết hơn tại: https://www.mbageas.life/
107
- {context_with_memory}
108
- Câu hỏi: {query}
109
- Câu trả lời:"""
110
  }
111
- ]
112
 
113
- response_content = client.chat_completion(messages=messages, max_tokens=1024, stream=False)
114
- response = response_content.choices[0].message.content.split("Answer:")[-1].strip()
115
- return response
 
 
 
 
 
 
 
 
 
 
116
 
 
 
 
117
 
118
- def process_feedback(query, response, feedback):
119
- # st.write(f"Feedback received: {'👍' if feedback else '👎'} for query: {query}")
120
- if feedback:
121
- # If thumbs up, store the response in memory buffer
122
- memory.chat_memory.add_ai_message(response)
123
- else:
124
- # If thumbs down, remove the response from memory buffer and regenerate the response
125
- # memory.chat_memory.messages = [msg for msg in memory.chat_memory.messages if msg.get("content") != response]
126
- new_query=f"{query}. Tạo câu trả lời đúng với câu hỏi"
127
- new_response = rag_query(new_query)
128
- st.markdown(new_response)
129
- memory.chat_memory.add_ai_message(new_response)
130
 
131
- # Streamlit interface
 
132
 
133
- st.title("Chào mừng bạn đã đến với MBAL Chatbot")
134
- st.markdown("***")
135
- st.info('''
136
- Tôi sẽ giải đáp các thắc mắc của bạn liên quan đến các sản phẩm bảo hiểm nhân thọ của MB Ageas Life''')
137
 
138
- col1, col2 = st.columns(2)
 
 
 
 
 
139
 
140
- with col1:
141
- chat = st.button("Chat")
142
- if chat:
143
- st.switch_page("pages/chatbot.py")
144
 
145
- with col2:
146
- rag = st.button("Store Document")
147
- if rag:
148
- st.switch_page("pages/management.py")
149
 
150
- st.markdown("<div style='text-align:center;'></div>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain_community.embeddings import HuggingFaceEmbeddings
3
+ from langchain.prompts import PromptTemplate
4
+ import os
5
+ from langchain.memory import ConversationBufferWindowMemory
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ import time
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
 
 
9
  import streamlit as st
10
+ import os
11
+
12
+ st.set_page_config(page_title="MBAL CHATBOT")
13
+ col1, col2, col3 = st.columns([1,2,1])
14
+ with col2:
15
+ st.title("GymGPT 🦾")
16
+
17
+
18
+ st.sidebar.title("Welcome to GymGPT")
19
+ st.sidebar.title("Shoot your gym-related questions")
20
+ st.markdown(
21
+ """
22
+ <style>
23
+ div.stButton > button:first-child {
24
+ background-color: #ffd0d0;
25
+ }
26
+
27
+ div.stButton > button:active {
28
+ background-color: #ff6262;
29
+ }
30
+
31
+ .st-emotion-cache-6qob1r {
32
+ position: relative;
33
+ height: 100%;
34
+ width: 100%;
35
+ background-color: black;
36
+ overflow: overlay;
37
+ }
38
+
39
+ div[data-testid="stStatusWidget"] div button {
40
+ display: none;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
 
42
 
43
+ .reportview-container {
44
+ margin-top: -2em;
45
+ }
46
+ #MainMenu {visibility: hidden;}
47
+ .stDeployButton {display:none;}
48
+ footer {visibility: hidden;}
49
+ #stDecoration {display:none;}
50
+ button[title="View fullscreen"]{
51
+ visibility: hidden;}
52
+ </style>
53
+ """,
54
+ unsafe_allow_html=True,
55
+ )
56
 
57
+ def reset_conversation():
58
+ st.session_state.messages = []
59
+ st.session_state.memory.clear()
60
 
61
+ if "messages" not in st.session_state:
62
+ st.session_state.messages = []
 
 
 
 
 
 
 
 
 
 
63
 
64
+ if "memory" not in st.session_state:
65
+ st.session_state.memory = ConversationBufferWindowMemory(k=2, memory_key="chat_history",return_messages=True)
66
 
67
+ embeddings = HuggingFaceEmbeddings(model_name="bkai-foundation-models/vietnamese-bi-encoder", model_kwargs={"trust_remote_code": True})
68
+ db = FAISS.load_local("mbal_faiss_db", embeddings,allow_dangerous_deserialization= True)
69
+ db_retriever = db.as_retriever(search_type="similarity",search_kwargs={"k": 4})
 
70
 
71
+ prompt_template = """<s>
72
+ {context}
73
+ CHAT HISTORY: {chat_history}[/INST]
74
+ ASSISTANT:
75
+ </s>
76
+ """
77
 
78
+ prompt = PromptTemplate(template=prompt_template,
79
+ input_variables=['question', 'context', 'chat_history'])
 
 
80
 
 
 
 
 
81
 
82
+ llm = ChatGroq(temperature = 0.5,groq_api_key=os.environ["GROQ_API_KEY"],model_name="llama3-7b")
83
+
84
+ # Create a conversational chain using only your database retriever
85
+ qa = ConversationalRetrievalChain.from_llm(
86
+ llm=llm,
87
+ memory=st.session_state.memory,
88
+ retriever=db_retriever,
89
+ combine_docs_chain_kwargs={'prompt': prompt}
90
+ )
91
+
92
+ for message in st.session_state.messages:
93
+ with st.chat_message(message.get("role")):
94
+ st.write(message.get("content"))
95
+
96
+ input_prompt = st.chat_input("Say something")
97
+
98
+ if input_prompt:
99
+ with st.chat_message("user"):
100
+ st.write(input_prompt)
101
+
102
+ st.session_state.messages.append({"role":"user","content":input_prompt})
103
+
104
+ with st.chat_message("assistant"):
105
+ with st.status("Lifting data, one bit at a time 💡...",expanded=True):
106
+ result = qa.invoke(input=input_prompt)
107
+
108
+ message_placeholder = st.empty()
109
+
110
+ full_response = "⚠️ **_Note: Information provided may be inaccurate._** \n\n\n"
111
+ for chunk in result["answer"]:
112
+ full_response+=chunk
113
+ time.sleep(0.02)
114
+
115
+ message_placeholder.markdown(full_response+" ▌")
116
+ st.button('Reset All Chat 🗑️', on_click=reset_conversation)
117
+
118
+ st.session_state.messages.append({"role":"assistant","content":result["answer"]})