Update app.py
app.py
CHANGED
@@ -1,57 +1,39 @@
-import streamlit as st
-import
 import os
-import
-
-
-
-
-
-# pipeline("text-classification", model='model', device=0)
-
-from langchain_huggingface import HuggingFaceEmbeddings  # Hugging Face sentence-transformer embedding models
-from langchain_huggingface.llms import HuggingFacePipeline  # works like a transformers pipeline
-
-from langchain.memory import ConversationBufferMemory  # Deprecated
-from langchain_community.chat_message_histories import ChatMessageHistory  # Deprecated
-from langchain_community.document_loaders import PyPDFLoader, TextLoader  # PDF processing
-from langchain.chains import ConversationalRetrievalChain  # Deprecated
-from langchain_experimental.text_splitter import SemanticChunker  # module for chunking text
-
-from langchain_chroma import Chroma  # AI-native vector database (built to handle large-scale AI workloads efficiently)
-from langchain_text_splitters import RecursiveCharacterTextSplitter  # recursively split text, then merge pieces while merge_size < chunk_size
-from langchain_core.runnables import RunnablePassthrough  # useful for testing (makes examples easy to execute and experiment with)
-from langchain_core.output_parsers import StrOutputParser  # format the LLM's output text into a list, dict, or any custom structure
-from langchain import hub
-from langchain_core.prompts import PromptTemplate
-import json
-from sentence_transformers import SentenceTransformer
-HF_TOKEN = st.secrets["HF_TOKEN"]
-
-# Save the RAG chain built from the PDF
-if 'rag_chain' not in st.session_state:
-    st.session_state.rag_chain = None
-
-# Check whether the models have been downloaded
-if 'models_loaded' not in st.session_state:
-    st.session_state.models_loaded = False
-
-# Save the downloaded embedding model
-if 'embeddings' not in st.session_state:
-    st.session_state.embeddings = None
-
-# Save the downloaded LLM
-if 'llm' not in st.session_state:
-    st.session_state.llm = None
-
-@st.cache_resource  # cache the embedding model to avoid reloading it on every rerun
-def load_embeddings():
-    return SentenceTransformer("bkai-foundation-models/vietnamese-bi-encoder")
 
 
 @st.cache_resource
-def load_llm():
-
     MODEL_NAME = "google/gemma-2b-it"
 
     model = AutoModelForCausalLM.from_pretrained(
@@ -73,95 +55,94 @@ def load_llm():
 
     return HuggingFacePipeline(pipeline=model_pipeline)
 
-
-
-
-
-
-
-    docs = []
 
-
     for _, row in df.iterrows():
         chunk_with_metadata = Document(
             page_content=row['page_content'],
             metadata={
                 'chunk_id': row['chunk_id'],
-                'document_title': row['document_title']
-
             }
         )
-
-
-
 
-
-
 
-
-Bạn là một chuyên viên tư vấn cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
 Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin trước rồi hỏi câu tiếp theo. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Khi có đủ thông tin khách hàng thì mới mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
 {context}
 Câu hỏi: {question}
-Trả lời:
-
-
-
-
-
-    {
     | prompt
-
     | parser
 )
-
-
-
-
-
-st.title(
-
-
-
-
-
-
-if not st.session_state
-
-
-st.session_state.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        st.write("🧠 **Trả lời:**")
-        st.markdown(result["answer"])
-
-    except json.JSONDecodeError:
-        st.error("⚠️ Output không đúng JSON")
-        st.text(raw_output)
-
-    # answer = output.split("Answer:")[1].strip() if "Answer:" in output else output.strip()
-    # st.write("**Trả lời:**")
-    # st.write(answer)
Updated app.py:

+import streamlit as st
+from langchain.llms import HuggingFacePipeline
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain.prompts.prompt import PromptTemplate
+from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from langchain.schema import Document
+from langchain_community.llms import HuggingFaceEndpoint
+from langchain.vectorstores import Chroma
+from transformers import TextStreamer
+from langchain.llms import HuggingFacePipeline
+from langchain.prompts import ChatPromptTemplate
+from langchain.llms import HuggingFaceHub
 import os
+import pandas as pd
+from langchain.vectorstores import FAISS
+import subprocess
+from langchain_community.llms import HuggingFaceHub
+
+import pandas as pd
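Several of these imports come from legacy `langchain.*` paths that current LangChain releases split into partner packages; the previous version of this file already imported a few of them that way. A minimal sketch of the equivalent imports, assuming the langchain-huggingface, langchain-chroma, langchain-community and langchain-core packages are installed:

# Equivalent imports from the split packages (assumes they are installed)
from langchain_huggingface import HuggingFaceEmbeddings      # instead of langchain.embeddings
from langchain_huggingface.llms import HuggingFacePipeline   # instead of langchain.llms
from langchain_chroma import Chroma                          # instead of langchain.vectorstores
from langchain_community.vectorstores import FAISS           # instead of langchain.vectorstores
from langchain_core.documents import Document                # instead of langchain.schema
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate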
 
+# Model configuration
+MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+model_name = "google/gemma-2-2b"
+TOKEN = os.getenv('HF_TOKEN')
+subprocess.run(["huggingface-cli", "login", "--token", TOKEN, "--add-to-git-credential"])
+######
+# Set the Hugging Face token as an environment variable
+os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]
 
+
+# Load the LLM (cached so the model is downloaded only once per session)
 @st.cache_resource
+def load_model():
+    # MODEL_NAME = "lmsys/vicuna-7b-v1.5"
     MODEL_NAME = "google/gemma-2b-it"
 
     model = AutoModelForCausalLM.from_pretrained(
 
     return HuggingFacePipeline(pipeline=model_pipeline)
 
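The diff collapses the body of load_model() between `from_pretrained(` and the `return`. A minimal sketch of the usual pattern those lines follow, assuming a standard transformers text-generation pipeline; the dtype and generation settings are assumptions, and `load_model_sketch` is only an illustrative name:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface.llms import HuggingFacePipeline

def load_model_sketch():
    MODEL_NAME = "google/gemma-2b-it"
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)  # assumed dtype
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,  # assumed generation budget
    )
    return HuggingFacePipeline(pipeline=model_pipeline)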
+# Initialize embeddings
+@st.cache_resource
+def load_embeddings():
+    # Vietnamese bi-encoder sentence-embedding model
+    embeddings = HuggingFaceEmbeddings(model_name='bkai-foundation-models/vietnamese-bi-encoder')
+    # embeddings = OpenAIEmbeddings()
+    return embeddings
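A quick way to confirm the embedding model loads and returns vectors (the query string is just an example):

emb = load_embeddings()
vec = emb.embed_query("bảo hiểm nhân thọ")  # example query; returns a list of floats
print(len(vec))                             # embedding dimensionality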
 
+# Chroma vector store
+@st.cache_resource
+def setup_vector():
+    chunks = []
+    df = pd.read_excel(r"chunk_metadata_template.xlsx")
     for _, row in df.iterrows():
         chunk_with_metadata = Document(
             page_content=row['page_content'],
             metadata={
                 'chunk_id': row['chunk_id'],
+                'document_title': row['document_title'],
             }
         )
+        chunks.append(chunk_with_metadata)
+    embeddings = HuggingFaceEmbeddings(model_name='bkai-foundation-models/vietnamese-bi-encoder')
+    return Chroma.from_documents(chunks, embedding=embeddings)
 
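Before wiring the chain, the store can be queried directly to check that retrieval returns sensible chunks (the query text and k are just examples):

vector = setup_vector()
hits = vector.similarity_search("quyền lợi bảo hiểm", k=5)  # example query
for doc in hits:
    print(doc.metadata['chunk_id'], doc.page_content[:80])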
+# Set up the conversation chain
+def setup_conversation_chain():
+    llm = load_model()
+    vector = setup_vector()
+    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)  # used only by the commented-out ConversationalRetrievalChain variant below
 
+    template = """Bạn là một chuyên viên tư vấn cho khách hàng về sản phẩm bảo hiểm của công ty MB Ageas Life tại Việt Nam.
 Hãy trả lời chuyên nghiệp, chính xác, cung cấp thông tin trước rồi hỏi câu tiếp theo. Tất cả các thông tin cung cấp đều trong phạm vi MBAL. Khi có đủ thông tin khách hàng thì mới mời khách hàng đăng ký để nhận tư vấn trên https://www.mbageas.life/
 {context}
 Câu hỏi: {question}
+Trả lời:"""
+
+
+    # PROMPT = ChatPromptTemplate.from_template(template=template)
+    # chain = ConversationalRetrievalChain.from_llm(
+    #     llm=llm,
+    #     retriever=vector.as_retriever(search_kwargs={'k': 5}),
+    #     memory=memory,
+    #     combine_docs_chain_kwargs={"prompt": PROMPT}
+    #     # condense_question_prompt=CUSTOM_QUESTION_PROMPT
+    # )
+    # format_docs, prompt and parser must be defined before this LCEL chain is built (see the sketch below)
+    chain = (
+        {"context": vector.as_retriever(search_kwargs={'k': 5}) | format_docs, "question": RunnablePassthrough()}
         | prompt
+        | llm
         | parser
     )
+
+    return chain
+
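The LCEL chain above pipes the retriever through `format_docs`, `prompt` and `parser`, but none of these are defined in this version of the file, and `RunnablePassthrough` is not imported. A minimal sketch of definitions that would make the chain runnable, assuming the prompt template string above and a plain string output parser (imports at the top of app.py, the rest inside setup_conversation_chain after `template`):

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

def format_docs(docs):
    # Join the retrieved Documents into one context string for the prompt
    return "\n\n".join(doc.page_content for doc in docs)

prompt = PromptTemplate.from_template(template)  # `template` is the string defined above
parser = StrOutputParser()                       # the chain then returns a plain string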
+# Streamlit app
+def main():
+    st.title("🛡️ MBAL Chatbot 🛡️")
+
+    # Initialize the conversation chain
+    if 'conversation_chain' not in st.session_state:
+        st.session_state.conversation_chain = setup_conversation_chain()
+
+    # Display previous chat messages
+    if 'messages' not in st.session_state:
+        st.session_state.messages = []
+
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # User input field
+    if prompt := st.chat_input("Bạn cần tư vấn về điều gì? Hãy chia sẻ nhu cầu và thông tin của bạn nhé!"):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        with st.chat_message("assistant"):
+            message_placeholder = st.empty()
+            full_response = ""
+
+            # Generate the response (the LCEL chain takes the question string and returns a plain string)
+            response = st.session_state.conversation_chain.invoke(prompt)
+            full_response = response
+            # full_response = response.get("answer", "No response generated.")
+
+            message_placeholder.markdown(full_response)
+
+        st.session_state.messages.append({"role": "assistant", "content": full_response})
+
+# if __name__ == "__main__":
+main()
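For debugging outside the Streamlit UI, the chain can be exercised directly in a Python shell, assuming the helper definitions sketched above are in place (the question is just an example; Streamlit may log warnings about a missing script run context):

chain = setup_conversation_chain()
answer = chain.invoke("Sản phẩm bảo hiểm nào phù hợp cho người 30 tuổi?")  # example question
print(answer)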