Aman Jain committed on
Commit
e85c8bb
·
1 Parent(s): 5263d95

Added features

Browse files
Files changed (2) hide show
  1. app.py +232 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from langchain_community.document_loaders import WebBaseLoader
4
+ from langchain_community.document_transformers import BeautifulSoupTransformer
5
+ import streamlit as st
6
+ from langchain_huggingface import HuggingFaceEndpoint
7
+ from langchain.indexes import VectorstoreIndexCreator
8
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain_core.output_parsers import StrOutputParser
11
+ from langchain_core.prompts import PromptTemplate
12
+ from langchain.chains import RetrievalQA
13
+
14
# Default Hugging Face Hub repository used for text generation.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"


def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """
    Returns a language model for HuggingFace inference.

    Parameters:
    - model_id (str): The ID of the HuggingFace model repository.
    - max_new_tokens (int): The maximum number of new tokens to generate.
    - temperature (float): The temperature for sampling from the model.

    Returns:
    - llm (HuggingFaceEndpoint): The language model for HuggingFace inference.
    """
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # `huggingfacehub_api_token` is the declared credential field of
        # HuggingFaceEndpoint. An undeclared `token=` keyword is not treated
        # as a credential; it is shunted into the generation kwargs instead.
        # The value is None when HF_TOKEN is not set in the environment.
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    )
    return llm
35
+
36
+
37
# --- Page chrome -----------------------------------------------------------
# The browser-tab title now matches the on-page title (was "Retirever").
st.set_page_config(page_title="Website Information Retriever Agent", page_icon="🤗")
st.title("Website Information Retriever Agent")
st.markdown(
    "*This is a simple chatbot that uses the HuggingFace transformers library "
    "to generate responses to your text input.It uses the model "
    "mistralai/Mistral-7B-Instruct-v0.3. You can enter the specific website "
    "url and the use the agent to gather information.*"
)

# --- Session-state defaults ------------------------------------------------
# Each key is created only if it is missing, so values already stored in the
# session are left untouched here.
_SESSION_DEFAULTS = {
    "avatars": {'user': None, 'assistant': None},
    "user_text": None,
    "sitemap_url": None,
    "max_response_length": 256,
    "system_message": "friendly AI conversing with a human user",
    "starter_message": "Hello, there! How can I help you today?",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# --- Sidebar for settings --------------------------------------------------
with st.sidebar:
    st.header("System Settings")

    # AI Settings
    st.session_state.system_message = st.text_area(
        "System Message", value="You are a friendly AI conversing with a human user."
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value="Hello, there! How can I help you today?"
    )

    # Model Settings
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=256
    )

    # Avatar Selection
    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar", options=["🤗", "💬", "🤖"], index=0
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
        )
    # Reset Chat History
    reset_history = st.button("Reset Chat History")

# Initialize or reset chat history; it always starts with the assistant's
# configured greeting.
if "chat_history" not in st.session_state or reset_history:
    st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
99
+
100
@st.cache_resource(show_spinner=False)
def build_retriever(url):
    """
    Builds a FAISS-backed retriever over the content of a single web page.

    The page is loaded with WebBaseLoader, split into chunks of at most 1000
    characters (10 characters of overlap), embedded with the
    "sentence-transformers/all-mpnet-base-v2" model and indexed in FAISS.

    The result is cached per ``url`` with ``st.cache_resource``. Streamlit
    re-executes the whole script on every interaction, so without the cache
    each chat message would re-download the page, reload the embedding model
    and rebuild the index.

    Args:
        url (str): Address of the page to index.

    Returns:
        A retriever that returns the 4 most similar chunks for a query.
    """
    pages = WebBaseLoader([url]).load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=10,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    chunks = splitter.split_documents(pages)

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    index = FAISS.from_documents(chunks, embeddings)
    return index.as_retriever(search_kwargs={"k": 4})


# `sitemap_url` is defined unconditionally because the chat interface further
# down reads it as a module-level name.
sitemap_url = st.text_input("URL to the website", value="")

if sitemap_url:
    with st.spinner("Processing..."):
        # `retriever` is read as a module-level name by get_response().
        retriever = build_retriever(sitemap_url)
131
+
132
+
133
+
134
+
135
+
136
+
137
def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """
    Generates a response from the chatbot model.

    The answer is produced by a RetrievalQA chain ("refine" chain type) over
    the module-level ``retriever``, so it is grounded in the indexed website.
    Only ``user_text`` is sent to the chain; ``system_message`` and the prior
    ``chat_history`` are accepted for interface compatibility but are not
    part of the prompt.

    Args:
        system_message (str): The system message for the conversation
            (currently not used for generation).
        chat_history (list): The list of previous chat messages. It is
            extended in place with the new user and assistant messages.
        user_text (str): The user's input text, used as the retrieval query.
        eos_token_id (list, optional): Accepted for compatibility; currently
            unused.
        max_new_tokens (int, optional): The maximum number of new tokens to generate.
        get_llm_hf_kws (dict, optional): Additional keyword arguments for
            get_llm_hf_inference; these override the defaults set here.

    Returns:
        tuple: A tuple containing the generated response and the updated chat history.
    """
    # Set up the model. Caller-supplied keyword arguments take precedence
    # over the defaults; None (rather than {}) is used as the parameter
    # default so no mutable object is shared between calls.
    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    llm_kwargs.update(get_llm_hf_kws or {})
    hf = get_llm_hf_inference(**llm_kwargs)

    # Build the retrieval chain over the indexed website content.
    qa = RetrievalQA.from_chain_type(
        llm=hf,
        chain_type="refine",
        retriever=retriever,
        return_source_documents=False,
    )

    # Generate the response. `invoke` replaces the deprecated `run`; the
    # chain's answer is stored under the "result" key.
    response = qa.invoke({"query": user_text})["result"]

    # Update the chat history
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+ # Chat interface
188
+
189
# Chat interface: rendered only once a website URL has been entered.
if sitemap_url:
    chat_box = st.container(border=True)
    with chat_box:
        # Messages are drawn into this container so they appear above the
        # input field.
        history_area = st.container()
        st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")

    with history_area:
        # Replay the stored conversation, leaving out any system messages.
        for entry in st.session_state.chat_history:
            speaker = entry['role']
            if speaker != 'system':
                with st.chat_message(speaker, avatar=st.session_state['avatars'][speaker]):
                    st.markdown(entry['content'])

        # A new message was submitted on this run.
        if st.session_state.user_text:
            user_avatar = st.session_state.avatars['user']
            assistant_avatar = st.session_state.avatars['assistant']

            # Echo the user's message right away.
            with st.chat_message("user", avatar=user_avatar):
                st.markdown(st.session_state.user_text)

            # Show a spinner in the assistant bubble until the answer arrives.
            with st.chat_message("assistant", avatar=assistant_avatar):
                with st.spinner("Thinking..."):
                    response, st.session_state.chat_history = get_response(
                        system_message=st.session_state.system_message,
                        user_text=st.session_state.user_text,
                        chat_history=st.session_state.chat_history,
                        max_new_tokens=st.session_state.max_response_length,
                    )
                    st.markdown(response)
231
+
232
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ smolagents
2
+ pandas
3
+ langchain
4
+ langchain-community
5
+ sentence-transformers
6
+ faiss-cpu
7
+ langchain_huggingface
8
+ langchain_core
9
+ streamlit
10
+ huggingface_hub
11
+ transformers
12
+ accelerate
13
+ langchain_text_splitters
14
+ beautifulsoup4
15
+ playwright