LamiaYT committed on
Commit
8d36e0e
·
1 Parent(s): 7343388
Files changed (1) hide show
  1. agent.py +151 -131
agent.py CHANGED
@@ -14,13 +14,14 @@ serper_api_key = os.getenv("SERPER_API_KEY")
14
  # ---- Imports ----
15
  from langgraph.graph import START, StateGraph, MessagesState
16
  from langgraph.prebuilt import tools_condition, ToolNode
17
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
18
  from langchain_community.tools.tavily_search import TavilySearchResults
19
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
20
  from langchain_community.vectorstores import Chroma
21
  from langchain_core.documents import Document
22
- from langchain_core.messages import SystemMessage, HumanMessage
23
  from langchain_core.tools import tool
 
24
  from langchain.tools.retriever import create_retriever_tool
25
  from langchain.vectorstores import Chroma
26
  from langchain.embeddings import HuggingFaceEmbeddings
@@ -32,6 +33,61 @@ import re
32
  import math
33
  from datetime import datetime
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # ---- Enhanced Tools ----
36
 
37
  @tool
@@ -105,16 +161,25 @@ def compound_interest(principal: float, rate: float, time: float, n: int = 1) ->
105
  """Calculate compound interest"""
106
  return principal * (1 + rate/n) ** (n * time)
107
 
 
 
 
 
 
 
 
 
 
108
  @tool
109
  def wiki_search(query: str) -> str:
110
  """Search Wikipedia for information"""
111
  try:
112
- search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
113
  if not search_docs:
114
  return "No Wikipedia results found."
115
 
116
  formatted = "\n\n---\n\n".join([
117
- f'<Document source="{doc.metadata.get("source", "Wikipedia")}" title="{doc.metadata.get("title", "Unknown")}"/>\n{doc.page_content[:2000]}\n</Document>'
118
  for doc in search_docs
119
  ])
120
  return formatted
@@ -125,12 +190,12 @@ def wiki_search(query: str) -> str:
125
  def web_search(query: str) -> str:
126
  """Search the web using Tavily"""
127
  try:
128
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
129
  if not search_docs:
130
  return "No web search results found."
131
 
132
  formatted = "\n\n---\n\n".join([
133
- f'<Document source="{doc.get("url", "Unknown")}" title="{doc.get("title", "Unknown")}"/>\n{doc.get("content", "")[:2000]}\n</Document>'
134
  for doc in search_docs
135
  ])
136
  return formatted
@@ -138,56 +203,24 @@ def web_search(query: str) -> str:
138
  return f"Web search error: {str(e)}"
139
 
140
  @tool
141
def arxiv_search(query: str) -> str:
    """Search ArXiv for academic papers"""
    # The docstring above is left verbatim: this function is registered with
    # @tool, which uses the docstring as the tool description.
    #
    # Loads at most two papers for `query` and returns them as
    # "---"-separated pseudo-XML chunks, each capped at 1500 characters of
    # page content. Any failure is reported as a string, never raised.
    try:
        papers = ArxivLoader(query=query, load_max_docs=2).load()
        if not papers:
            return "No ArXiv results found."

        chunks = []
        for paper in papers:
            meta = paper.metadata
            source = meta.get("source", "ArXiv")
            title = meta.get("Title", "Unknown")
            body = paper.page_content[:1500]
            chunks.append(
                f'<Document source="{source}" title="{title}"/>\n{body}\n</Document>'
            )
        return "\n\n---\n\n".join(chunks)
    except Exception as e:
        return f"ArXiv search error: {e}"
155
-
156
- @tool
157
def serper_search(query: str) -> str:
    """Enhanced web search using Serper API"""
    # The docstring above is left verbatim: this function is registered with
    # @tool, which uses the docstring as the tool description.
    #
    # Posts `query` to the Serper search endpoint and returns the first three
    # organic results as "---"-separated pseudo-XML chunks. Every failure
    # (missing key, HTTP error, timeout, malformed body) is reported as a
    # string rather than raised, so the agent loop keeps running.
    if not serper_api_key:
        return "Serper API key not available"

    try:
        response = requests.post(
            "https://google.serper.dev/search",
            headers={
                'X-API-KEY': serper_api_key,
                'Content-Type': 'application/json'
            },
            data=json.dumps({"q": query, "num": 5}),
            # Without a timeout a stalled connection blocks the agent forever.
            timeout=30,
        )
        # Surface 4xx/5xx (bad key, quota exhausted) as an error instead of
        # trying to parse an error body as search results.
        response.raise_for_status()
        results = response.json()

        organic = results.get('organic') if isinstance(results, dict) else None
        # Covers both a missing key and an empty result list; the latter
        # previously produced an empty string.
        if not organic:
            return "No search results found"

        formatted = "\n\n---\n\n".join([
            f'<Document source="{result.get("link", "Unknown")}" title="{result.get("title", "Unknown")}"/>\n{result.get("snippet", "")}\n</Document>'
            for result in organic[:3]
        ])
        return formatted
    except Exception as e:
        return f"Serper search error: {str(e)}"
186
 
187
  # ---- Embedding & Vector Store Setup ----
188
  def setup_vector_store():
189
  try:
190
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
191
 
192
  # Check if metadata.jsonl exists and load it
193
  if os.path.exists('metadata.jsonl'):
@@ -195,16 +228,20 @@ def setup_vector_store():
195
  with open('metadata.jsonl', 'r') as jsonl_file:
196
  for line in jsonl_file:
197
  if line.strip(): # Skip empty lines
198
- json_QA.append(json.loads(line))
 
 
 
199
 
200
  if json_QA:
201
- documents = [
202
- Document(
203
- page_content=f"Question: {sample.get('Question', '')}\n\nFinal answer: {sample.get('Final answer', '')}",
204
- metadata={"source": sample.get("task_id", "unknown")}
205
- )
206
- for sample in json_QA if sample.get('Question') and sample.get('Final answer')
207
- ]
 
208
 
209
  if documents:
210
  vector_store = Chroma.from_documents(
@@ -228,7 +265,6 @@ def setup_vector_store():
228
 
229
  except Exception as e:
230
  print(f"Vector store setup error: {e}")
231
- # Return a dummy vector store function
232
  return None
233
 
234
  vector_store = setup_vector_store()
@@ -237,15 +273,15 @@ vector_store = setup_vector_store()
237
  def similar_question_search(query: str) -> str:
238
  """Search for similar questions in the knowledge base"""
239
  if not vector_store:
240
- return "Vector store not available"
241
 
242
  try:
243
- matched_docs = vector_store.similarity_search(query, 3)
244
  if not matched_docs:
245
  return "No similar questions found"
246
 
247
- formatted = "\n\n---\n\n".join([
248
- f'<Document source="{doc.metadata.get("source", "Unknown")}" />\n{doc.page_content[:1000]}\n</Document>'
249
  for doc in matched_docs
250
  ])
251
  return formatted
@@ -254,110 +290,97 @@ def similar_question_search(query: str) -> str:
254
 
255
  # ---- Enhanced System Prompt ----
256
  system_prompt = """
257
- You are an expert assistant capable of solving complex questions using available tools. You have access to:
258
 
259
- 1. Mathematical tools: add, subtract, multiply, divide, modulus, power, square_root, factorial, gcd, lcm, percentage, compound_interest
260
- 2. Search tools: wiki_search, web_search, arxiv_search, serper_search, similar_question_search
 
261
 
262
- IMPORTANT INSTRUCTIONS:
263
- 1. Break down complex questions into smaller steps
264
- 2. Use tools systematically to gather information and perform calculations
265
- 3. For mathematical problems, show your work step by step
266
- 4. For factual questions, search for current and accurate information
267
- 5. Cross-reference information from multiple sources when possible
268
- 6. Be precise with numbers - avoid rounding unless necessary
269
 
270
- When providing your final answer, use this exact format:
271
- FINAL ANSWER: [YOUR ANSWER]
272
 
273
- Rules for the final answer:
274
- - Numbers: Use plain digits without commas, units, or symbols (unless specifically requested)
275
- - Strings: Use exact names without articles or abbreviations
276
- - Lists: Comma-separated values following the above rules
277
- - Be concise and accurate
278
 
279
- Think step by step and use the available tools to ensure accuracy.
280
  """
281
 
282
  sys_msg = SystemMessage(content=system_prompt)
283
 
284
- # ---- Enhanced Tool List ----
285
  tools = [
286
  # Math tools
287
  multiply, add, subtract, divide, modulus, power, square_root,
288
- factorial, gcd, lcm, percentage, compound_interest,
289
  # Search tools
290
- wiki_search, web_search, arxiv_search, serper_search, similar_question_search
291
  ]
292
 
293
  # ---- Graph Definition ----
294
  def build_graph(provider: str = "huggingface"):
295
- """Build the agent graph with improved HuggingFace model"""
296
 
297
  if provider == "huggingface":
298
- # Use a more capable model from HuggingFace
299
- try:
300
- # Try with a well-supported model first
301
- endpoint = HuggingFaceEndpoint(
302
- repo_id="google/flan-t5-base", # This model works well with the current setup
303
- temperature=0.1,
304
- huggingfacehub_api_token=hf_token,
305
- max_new_tokens=512,
306
- task="text2text-generation"
307
- )
308
- llm = ChatHuggingFace(llm=endpoint)
309
- except Exception as e:
310
- print(f"Failed to initialize google/flan-t5-base: {e}")
311
- # Fallback to another model
312
  try:
313
- endpoint = HuggingFaceEndpoint(
314
- repo_id="microsoft/DialoGPT-medium",
315
- temperature=0.1,
316
- huggingfacehub_api_token=hf_token,
317
- max_new_tokens=512
318
- )
319
- llm = ChatHuggingFace(llm=endpoint)
320
- except Exception as e2:
321
- print(f"Failed to initialize DialoGPT-medium: {e2}")
322
- # Final fallback
323
- endpoint = HuggingFaceEndpoint(
324
- repo_id="bigscience/bloom-560m",
325
- temperature=0.1,
326
- huggingfacehub_api_token=hf_token,
327
- max_new_tokens=256
328
- )
329
- llm = ChatHuggingFace(llm=endpoint)
330
  else:
331
- raise ValueError("Only 'huggingface' provider is supported in this version.")
332
 
333
- llm_with_tools = llm.bind_tools(tools)
 
 
334
 
335
  def assistant(state: MessagesState):
336
- """Enhanced assistant node with better error handling"""
337
  try:
338
  messages = state["messages"]
339
- response = llm_with_tools.invoke(messages)
340
  return {"messages": [response]}
341
  except Exception as e:
342
  print(f"Assistant error: {e}")
343
- # Fallback response
344
- fallback_msg = HumanMessage(content=f"I encountered an error: {str(e)}. Let me try a simpler approach.")
345
- return {"messages": [fallback_msg]}
346
 
347
  def retriever(state: MessagesState):
348
- """Enhanced retriever with better context injection"""
349
  messages = state["messages"]
350
  user_query = messages[-1].content if messages else ""
351
 
352
- # Try to find similar questions
353
  context_messages = [sys_msg]
354
 
 
355
  if vector_store:
356
  try:
357
- similar = vector_store.similarity_search(user_query, k=2)
358
  if similar:
359
  context_msg = HumanMessage(
360
- content=f"Here are similar questions for context:\n\n{similar[0].page_content}"
361
  )
362
  context_messages.append(context_msg)
363
  except Exception as e:
@@ -365,16 +388,13 @@ def build_graph(provider: str = "huggingface"):
365
 
366
  return {"messages": context_messages + messages}
367
 
368
- # Build the graph
369
  builder = StateGraph(MessagesState)
370
  builder.add_node("retriever", retriever)
371
  builder.add_node("assistant", assistant)
372
- builder.add_node("tools", ToolNode(tools))
373
 
374
- # Define edges
375
  builder.add_edge(START, "retriever")
376
  builder.add_edge("retriever", "assistant")
377
- builder.add_conditional_edges("assistant", tools_condition)
378
- builder.add_edge("tools", "assistant")
379
 
380
  return builder.compile()
 
14
  # ---- Imports ----
15
  from langgraph.graph import START, StateGraph, MessagesState
16
  from langgraph.prebuilt import tools_condition, ToolNode
17
+ from langchain_huggingface import HuggingFaceEmbeddings
18
  from langchain_community.tools.tavily_search import TavilySearchResults
19
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
20
  from langchain_community.vectorstores import Chroma
21
  from langchain_core.documents import Document
22
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
23
  from langchain_core.tools import tool
24
+ from langchain_core.language_models.base import BaseLanguageModel
25
  from langchain.tools.retriever import create_retriever_tool
26
  from langchain.vectorstores import Chroma
27
  from langchain.embeddings import HuggingFaceEmbeddings
 
33
  import math
34
  from datetime import datetime
35
 
36
+ # Custom HuggingFace LLM wrapper
37
class SimpleHuggingFaceLLM(BaseLanguageModel):
    """Thin client for the Hugging Face hosted Inference API.

    The prompt sent to the API is the ``content`` of the last message when a
    list is supplied, otherwise ``str(input)``; earlier messages (including
    any system message) are not forwarded. Failures are never raised: HTTP
    and transport errors come back as generated text starting with
    ``"Error: "``.

    NOTE(review): ``BaseLanguageModel`` is presumably an abstract pydantic
    model -- confirm against the installed ``langchain_core``. If so, the
    abstract methods it declares (e.g. ``generate_prompt``) still need
    implementations before this class can be instantiated.
    """

    # Declared with defaults so the assignments in __init__ are accepted if
    # the base class validates attribute names (pydantic models reject
    # assignment to undeclared fields).
    repo_id: str = ""
    hf_token: str = ""
    api_url: str = ""
    headers: dict = {}
    # Seconds before an Inference API request is abandoned; without a timeout
    # a stalled connection would block the agent indefinitely.
    request_timeout: float = 60.0

    def __init__(self, repo_id: str, hf_token: str):
        """Store the model id and token and derive the endpoint URL and auth header."""
        super().__init__()
        self.repo_id = repo_id
        self.hf_token = hf_token
        self.api_url = f"https://api-inference.huggingface.co/models/{repo_id}"
        self.headers = {"Authorization": f"Bearer {hf_token}"}

    @staticmethod
    def _to_prompt(value):
        """Return the last message's content for a list, else ``str(value)``."""
        if isinstance(value, list):
            return value[-1].content if value else ""
        return str(value)

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """POST the prompt to the Inference API and wrap the reply in an LLMResult.

        ``stop``, ``run_manager`` and ``kwargs`` are accepted but not used.
        """
        # Imported before the try block: the original imported these names
        # only inside the HTTP-200 branch, so the non-200 branch and the
        # exception handler both failed with UnboundLocalError.
        from langchain_core.outputs import LLMResult, Generation

        payload = {
            "inputs": self._to_prompt(messages),
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.1,
                "return_full_text": False
            }
        }

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.request_timeout,
            )
            if response.status_code == 200:
                result = response.json()
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get('generated_text', '')
                else:
                    # Unexpected payload shape: pass it through as text.
                    generated_text = str(result)
            else:
                generated_text = f"Error: {response.status_code}"
        except Exception as e:
            generated_text = f"Error: {str(e)}"

        return LLMResult(generations=[[Generation(text=generated_text)]])

    def invoke(self, input, config=None, **kwargs):
        """Run one generation and return it as an ``AIMessage``.

        ``config`` and ``kwargs`` are accepted but not used.
        """
        result = self._generate(self._to_prompt(input))
        generated_text = result.generations[0][0].text
        return AIMessage(content=generated_text)

    @property
    def _llm_type(self):
        """Identifier string for this LLM implementation."""
        return "huggingface_custom"
90
+
91
  # ---- Enhanced Tools ----
92
 
93
  @tool
 
161
  """Calculate compound interest"""
162
  return principal * (1 + rate/n) ** (n * time)
163
 
164
+ @tool
165
def calculate_average(numbers: str) -> float:
    """Calculate average of comma-separated numbers"""
    # The docstring above is left verbatim: this function is registered with
    # @tool, which uses the docstring as the tool description.
    #
    # Returns the arithmetic mean of the parsed values, or 0.0 when the input
    # holds no numbers or cannot be parsed (best-effort contract kept from
    # the original).
    try:
        # Blank tokens (trailing comma, ",,") are skipped so one stray
        # separator no longer turns the whole result into 0.0.
        values = [float(token) for token in numbers.split(',') if token.strip()]
    except (AttributeError, TypeError, ValueError):
        # Non-string input or a non-numeric token. Named explicitly so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return 0.0

    if not values:
        return 0.0
    return sum(values) / len(values)
172
+
173
  @tool
174
  def wiki_search(query: str) -> str:
175
  """Search Wikipedia for information"""
176
  try:
177
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
178
  if not search_docs:
179
  return "No Wikipedia results found."
180
 
181
  formatted = "\n\n---\n\n".join([
182
+ f'Wikipedia: {doc.metadata.get("title", "Unknown")}\n{doc.page_content[:1500]}'
183
  for doc in search_docs
184
  ])
185
  return formatted
 
190
def web_search(query: str) -> str:
    """Search the web using Tavily"""
    # The docstring above is left verbatim; the diff shows the other search
    # helpers registered with @tool, which uses the docstring as the tool
    # description.
    #
    # Returns up to two results as "---"-separated "Web: <title>" chunks,
    # each capped at 1500 characters of content. Failures are reported as a
    # string, never raised.
    #
    # NOTE(review): the diff hunk hides the line between `return formatted`
    # and the error return; it is reconstructed here as
    # `except Exception as e:` because the error return uses `e`. Confirm
    # against the full file.
    try:
        # invoke() takes the tool input as its first positional argument.
        # The original `invoke(query=query)` left that argument unfilled, so
        # every call raised TypeError and fell into the error branch.
        search_docs = TavilySearchResults(max_results=2).invoke({"query": query})
        if not search_docs:
            return "No web search results found."

        if isinstance(search_docs, str):
            # Presumably the tool's own error text (TODO confirm). Iterating
            # a string below would only mask it with an AttributeError.
            return f"Web search error: {search_docs}"

        formatted = "\n\n---\n\n".join([
            f'Web: {doc.get("title", "Unknown")}\n{doc.get("content", "")[:1500]}'
            for doc in search_docs
        ])
        return formatted
    except Exception as e:
        return f"Web search error: {str(e)}"
204
 
205
  @tool
206
def simple_calculation(expression: str) -> str:
    """Safely evaluate simple mathematical expressions"""
    # The docstring above is left verbatim: this function is registered with
    # @tool, which uses the docstring as the tool description.
    #
    # Returns the result as a string, "Invalid characters in expression" when
    # anything outside digits, + - * / . ( ) and space is present, or a
    # "Calculation error: ..." message when evaluation fails.
    try:
        # Whitelist check: with no letters, quotes or underscores allowed,
        # the expression cannot name any object or attribute.
        allowed = frozenset('0123456789+-*/.() ')
        if not all(ch in allowed for ch in expression):
            return "Invalid characters in expression"

        if not expression.strip():
            # Previously surfaced as an opaque syntax-error message.
            return "Calculation error: empty expression"

        # SECURITY: eval() on model-supplied text. The whitelist above is the
        # real guard; the empty namespace is defence in depth so no builtins
        # or module globals are reachable even if the whitelist is loosened.
        # NOTE(review): `**` is still permitted, so an input like 9**9**9**9
        # can burn CPU and memory; consider an ast-based evaluator with an
        # exponent bound.
        result = eval(expression, {"__builtins__": {}}, {})
        return str(result)
    except Exception as e:
        return f"Calculation error: {str(e)}"
219
 
220
  # ---- Embedding & Vector Store Setup ----
221
  def setup_vector_store():
222
  try:
223
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
224
 
225
  # Check if metadata.jsonl exists and load it
226
  if os.path.exists('metadata.jsonl'):
 
228
  with open('metadata.jsonl', 'r') as jsonl_file:
229
  for line in jsonl_file:
230
  if line.strip(): # Skip empty lines
231
+ try:
232
+ json_QA.append(json.loads(line))
233
+ except:
234
+ continue
235
 
236
  if json_QA:
237
+ documents = []
238
+ for sample in json_QA:
239
+ if sample.get('Question') and sample.get('Final answer'):
240
+ doc = Document(
241
+ page_content=f"Question: {sample['Question']}\n\nAnswer: {sample['Final answer']}",
242
+ metadata={"source": sample.get("task_id", "unknown")}
243
+ )
244
+ documents.append(doc)
245
 
246
  if documents:
247
  vector_store = Chroma.from_documents(
 
265
 
266
  except Exception as e:
267
  print(f"Vector store setup error: {e}")
 
268
  return None
269
 
270
  vector_store = setup_vector_store()
 
273
  def similar_question_search(query: str) -> str:
274
  """Search for similar questions in the knowledge base"""
275
  if not vector_store:
276
+ return "No similar questions available"
277
 
278
  try:
279
+ matched_docs = vector_store.similarity_search(query, k=2)
280
  if not matched_docs:
281
  return "No similar questions found"
282
 
283
+ formatted = "\n\n".join([
284
+ f'Similar Q&A:\n{doc.page_content[:800]}'
285
  for doc in matched_docs
286
  ])
287
  return formatted
 
290
 
291
  # ---- Enhanced System Prompt ----
292
  system_prompt = """
293
+ You are an expert assistant that can solve various types of questions using available tools.
294
 
295
+ Available tools:
296
+ - Math: add, subtract, multiply, divide, modulus, power, square_root, factorial, gcd, lcm, percentage, compound_interest, calculate_average, simple_calculation
297
+ - Search: wiki_search, web_search, similar_question_search
298
 
299
+ Instructions:
300
+ 1. Read the question carefully
301
+ 2. Break down complex problems into steps
302
+ 3. Use appropriate tools to gather information or perform calculations
303
+ 4. Think step by step and show your reasoning
304
+ 5. Provide accurate, concise answers
 
305
 
306
+ IMPORTANT: Always end your response with:
307
+ FINAL ANSWER: [your answer here]
308
 
309
+ For the final answer:
310
+ - Numbers: Use plain digits (no commas, units, or symbols unless requested)
311
+ - Text: Use exact names without articles
312
+ - Lists: Comma-separated values
 
313
 
314
+ Think carefully and use tools when needed.
315
  """
316
 
317
  sys_msg = SystemMessage(content=system_prompt)
318
 
319
+ # ---- Tool List ----
320
  tools = [
321
  # Math tools
322
  multiply, add, subtract, divide, modulus, power, square_root,
323
+ factorial, gcd, lcm, percentage, compound_interest, calculate_average, simple_calculation,
324
  # Search tools
325
+ wiki_search, web_search, similar_question_search
326
  ]
327
 
328
  # ---- Graph Definition ----
329
  def build_graph(provider: str = "huggingface"):
330
+ """Build the agent graph with custom HuggingFace integration"""
331
 
332
  if provider == "huggingface":
333
+ # Use custom HuggingFace LLM with fallback models
334
+ models_to_try = [
335
+ "google/flan-t5-base",
336
+ "microsoft/DialoGPT-medium",
337
+ "bigscience/bloom-560m"
338
+ ]
339
+
340
+ llm = None
341
+ for model_id in models_to_try:
 
 
 
 
 
342
  try:
343
+ llm = SimpleHuggingFaceLLM(repo_id=model_id, hf_token=hf_token)
344
+ print(f"Successfully initialized model: {model_id}")
345
+ break
346
+ except Exception as e:
347
+ print(f"Failed to initialize {model_id}: {e}")
348
+ continue
349
+
350
+ if llm is None:
351
+ raise ValueError("Failed to initialize any HuggingFace model")
 
 
 
 
 
 
 
 
352
  else:
353
+ raise ValueError("Only 'huggingface' provider is supported")
354
 
355
+ # Simple tool binding simulation
356
+ def llm_with_tools(messages):
357
+ return llm.invoke(messages)
358
 
359
  def assistant(state: MessagesState):
360
+ """Assistant node with enhanced error handling"""
361
  try:
362
  messages = state["messages"]
363
+ response = llm_with_tools(messages)
364
  return {"messages": [response]}
365
  except Exception as e:
366
  print(f"Assistant error: {e}")
367
+ fallback_response = AIMessage(content="I encountered an error processing your request. Let me try a simpler approach.")
368
+ return {"messages": [fallback_response]}
 
369
 
370
  def retriever(state: MessagesState):
371
+ """Enhanced retriever with context injection"""
372
  messages = state["messages"]
373
  user_query = messages[-1].content if messages else ""
374
 
 
375
  context_messages = [sys_msg]
376
 
377
+ # Add similar question context if available
378
  if vector_store:
379
  try:
380
+ similar = vector_store.similarity_search(user_query, k=1)
381
  if similar:
382
  context_msg = HumanMessage(
383
+ content=f"Here's a similar example:\n{similar[0].page_content[:500]}"
384
  )
385
  context_messages.append(context_msg)
386
  except Exception as e:
 
388
 
389
  return {"messages": context_messages + messages}
390
 
391
+ # Build simplified graph (without complex tool routing for now)
392
  builder = StateGraph(MessagesState)
393
  builder.add_node("retriever", retriever)
394
  builder.add_node("assistant", assistant)
 
395
 
396
+ # Simple linear flow
397
  builder.add_edge(START, "retriever")
398
  builder.add_edge("retriever", "assistant")
 
 
399
 
400
  return builder.compile()