Aman Jain committed
Commit c8be163 · 1 Parent(s): 15df868

Initial commit

Files changed (3)
  1. DATA/Telto_Userguide.pdf +0 -0
  2. app.py +278 -0
  3. requirements.txt +10 -0
DATA/Telto_Userguide.pdf ADDED
Binary file (542 kB).
 
app.py ADDED
@@ -0,0 +1,278 @@
+ import os
+ from typing import List, Dict
+
+ import pandas as pd
+ import streamlit as st
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+ from transformers.agents import Tool, HfApiEngine, ReactJsonAgent
+ from transformers.agents.llm_engine import MessageRole, get_clean_message_list
+ from huggingface_hub import InferenceClient
+ from langchain.docstore.document import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ # FAISS now lives in langchain_community, matching the other imports below.
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.vectorstores.utils import DistanceStrategy
+ from langchain_community.document_loaders import DirectoryLoader
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_groq import ChatGroq
+ from groq import Groq
+
+ token = os.getenv("HF_TOKEN")
+ groq_api_key = os.getenv("GROQ_API_KEY")
+
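+ # Note: GROQ_API_KEY must be available in the environment (e.g. a Space
+ # secret, or `export GROQ_API_KEY=...` locally) before the Groq client
+ # below is constructed.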
+ # model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+ # Load every PDF committed under DATA/ (a relative path, so the app also
+ # works when deployed rather than only on the author's machine).
+ loader = DirectoryLoader('DATA', glob="**/*.pdf", show_progress=True)
+ docs = loader.load()
+
+ # Chunk documents into ~200-token pieces, measured with the embedding
+ # model's own tokenizer so chunk sizes match what the encoder sees.
+ tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
+ text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
+     tokenizer,
+     chunk_size=200,
+     chunk_overlap=20,
+     add_start_index=True,
+     strip_whitespace=True,
+     separators=["\n\n", "\n", ".", " ", ""],
+ )
+
+ # Split documents and remove duplicates
+ docs_processed = []
+ unique_texts = {}
+ for doc in tqdm(docs):
+     new_docs = text_splitter.split_documents([doc])
+     for new_doc in new_docs:
+         if new_doc.page_content not in unique_texts:
+             unique_texts[new_doc.page_content] = True
+             docs_processed.append(new_doc)
+
+
+ model_name = "thenlper/gte-small"
+ model_kwargs = {'device': 'cpu'}
+ encode_kwargs = {'normalize_embeddings': False}
+ embedding_model = HuggingFaceEmbeddings(
+     model_name=model_name,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs,
+ )
+
+ # Create the vector database
+ vectordb = FAISS.from_documents(
+     documents=docs_processed,
+     embedding=embedding_model,
+     distance_strategy=DistanceStrategy.COSINE,
+ )
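+
+ # Sketch (commented out, not executed): the index can be sanity-checked
+ # directly before it is wired into the agent; the query string is only an
+ # illustrative example.
+ # hits = vectordb.similarity_search("factory reset procedure", k=3)
+ # for hit in hits:
+ #     print(hit.metadata.get("start_index"), hit.page_content[:80])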
+
+ class RetrieverTool(Tool):
+     name = "retriever"
+     description = "Using semantic similarity, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+         }
+     }
+     output_type = "string"
+
+     def __init__(self, vectordb, **kwargs):
+         super().__init__(**kwargs)
+         self.vectordb = vectordb
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         docs = self.vectordb.similarity_search(query, k=7)
+
+         return "\nRetrieved documents:\n" + "".join(
+             f"===== Document {i} =====\n" + doc.page_content for i, doc in enumerate(docs)
+         )
+
+
+ # Create an instance of the RetrieverTool
+ retriever_tool = RetrieverTool(vectordb)
+
+ # A LangChain Groq chat model; defined here but not used by the agent below,
+ # which talks to Groq directly through the OpenAIEngine wrapper.
+ llm = ChatGroq(
+     model="llama3-70b-8192",
+     temperature=0,
+     max_tokens=2048,
+ )
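+
+ # Sketch (commented out): the tool can also be exercised standalone; the
+ # query below is illustrative only.
+ # print(retriever_tool.forward("device network configuration"))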
+
+ openai_role_conversions = {
+     MessageRole.TOOL_RESPONSE: MessageRole.USER,
+ }
+
+ class OpenAIEngine:
+     """Minimal LLM engine for ReactJsonAgent that routes chat completions to Groq."""
+
+     def __init__(self, model_name="llama-3.3-70b-versatile"):
+         self.model_name = model_name
+         self.client = Groq(api_key=groq_api_key)
+
+     def __call__(self, messages, stop_sequences=None):
+         # Normalize agent message roles to ones the OpenAI-style API accepts.
+         messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
+
+         response = self.client.chat.completions.create(
+             model=self.model_name,
+             messages=messages,
+             stop=stop_sequences or [],
+             temperature=0.5,
+             max_tokens=2048,
+         )
+         return response.choices[0].message.content
+
+ llm_engine = OpenAIEngine()
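+
+ # Sketch (commented out): the engine is a plain callable over an OpenAI-style
+ # message list, so it can be smoke-tested directly:
+ # print(llm_engine([{"role": "user", "content": "Reply with OK."}]))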
+
+
+ # Create the agent
+ agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)
+
+ # Function to run the agent
+ def run_agentic_rag(question: str) -> str:
+     enhanced_question = f"""Using the information contained in your knowledge base, which you can access with the 'retriever' tool,
+ give a comprehensive answer to the question below.
+ Respond only to the question asked, response should be concise and relevant to the question.
+ If you cannot find information, do not give up and try calling your retriever again with different arguments!
+ Make sure to have covered the question completely by calling the retriever tool several times with semantically different queries.
+ Your queries should not be questions but affirmative form sentences: e.g. rather than "How do I load a model from the Hub in bf16?", query should be "load a model from the Hub bf16 weights".
+
+ Question:
+ {question}"""
+
+     return agent.run(enhanced_question)
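+
+ # Sketch (commented out; the question is a made-up example):
+ # print(run_agentic_rag("How do I reset my Telto device?"))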
+
+
+ # def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
+ #     """
+ #     Returns a language model for HuggingFace inference.
+ #
+ #     Parameters:
+ #     - model_id (str): The ID of the HuggingFace model repository.
+ #     - max_new_tokens (int): The maximum number of new tokens to generate.
+ #     - temperature (float): The temperature for sampling from the model.
+ #
+ #     Returns:
+ #     - llm (HuggingFaceEndpoint): The language model for HuggingFace inference.
+ #     """
+ #     llm = HuggingFaceEndpoint(
+ #         repo_id=model_id,
+ #         max_new_tokens=max_new_tokens,
+ #         temperature=temperature,
+ #         token=os.getenv("HF_TOKEN"),
+ #     )
+ #     return llm
+
+
+ def get_response(chat_history, user_text):
+     """
+     Generates a response from the agent and updates the chat history.
+
+     Args:
+         chat_history (list): The list of previous chat messages.
+         user_text (str): The user's input text.
+
+     Returns:
+         tuple: A tuple containing the generated response and the updated chat history.
+     """
+     # Run the agent once and reuse the answer, rather than invoking it twice.
+     response = run_agentic_rag(user_text)
+
+     # Update the chat history
+     chat_history.append({'role': 'user', 'content': user_text})
+     chat_history.append({'role': 'assistant', 'content': response})
+     return response, chat_history
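+
+ # Sketch (commented out): one turn through the helper outside Streamlit:
+ # history = [{"role": "assistant", "content": "Hello!"}]
+ # reply, history = get_response(chat_history=history, user_text="What is Telto?")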
+
+
+ st.set_page_config(page_title="Hi, I am Telto assistant", page_icon="🤗")
+ st.title("Telto Support")
+ st.markdown("*This is the Telto assistant. For any guidance on how to use Telto, feel free to ask me.*")
+
+ # Initialize session state for avatars
+ if "avatars" not in st.session_state:
+     st.session_state.avatars = {'user': None, 'assistant': None}
+
+ # Initialize session state for user text input
+ if 'user_text' not in st.session_state:
+     st.session_state.user_text = None
+
+ if "system_message" not in st.session_state:
+     st.session_state.system_message = "friendly AI conversing with a human user"
+
+ if "starter_message" not in st.session_state:
+     st.session_state.starter_message = "Hello, there! How can I help you today?"
+
+ # Sidebar for settings
+ with st.sidebar:
+     st.header("System Settings")
+
+     # Avatar Selection
+     st.markdown("*Select Avatars:*")
+     col1, col2 = st.columns(2)
+     with col1:
+         st.session_state.avatars['assistant'] = st.selectbox(
+             "AI Avatar", options=["🤗", "💬", "🤖"], index=0
+         )
+     with col2:
+         st.session_state.avatars['user'] = st.selectbox(
+             "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
+         )
+     # Reset Chat History
+     reset_history = st.button("Reset Chat History")
+
+ # Initialize or reset chat history
+ if "chat_history" not in st.session_state or reset_history:
+     st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
+
+ # Chat interface
+ chat_interface = st.container(border=True)
+ with chat_interface:
+     output_container = st.container()
+     st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
+
+ # Display chat messages
+ with output_container:
+     # For every message in the history
+     for message in st.session_state.chat_history:
+         # Skip the system message
+         if message['role'] == 'system':
+             continue
+
+         # Display the chat message using the correct avatar
+         with st.chat_message(message['role'],
+                              avatar=st.session_state['avatars'][message['role']]):
+             st.markdown(message['content'])
+
+     # When the user enters new text:
+     if st.session_state.user_text:
+         # Display the user's new message immediately
+         with st.chat_message("user",
+                              avatar=st.session_state.avatars['user']):
+             st.markdown(st.session_state.user_text)
+
+         # Display a spinner while waiting for the response
+         with st.chat_message("assistant",
+                              avatar=st.session_state.avatars['assistant']):
+             with st.spinner("Thinking..."):
+                 # Run the agent with the user text and history
+                 response, st.session_state.chat_history = get_response(
+                     user_text=st.session_state.user_text,
+                     chat_history=st.session_state.chat_history,
+                 )
+                 st.markdown(response)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ transformers
+ langchain
+ langchain-community
+ sentence-transformers
+ faiss-cpu
+ groq
+ langchain-groq
+ unstructured
+ unstructured[pdf]
+ langchain-huggingface
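+ # Note: app.py also imports streamlit, pandas, and tqdm; a Streamlit Space
+ # provides streamlit at runtime, but consider listing these explicitly when
+ # running elsewhere.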