wt002 commited on
Commit
26e01cd
Β·
verified Β·
1 Parent(s): 5a9cf8b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +31 -87
agent.py CHANGED
@@ -42,7 +42,7 @@ from io import StringIO
42
  from transformers import BertTokenizer, BertModel
43
  import torch
44
  import torch.nn.functional as F
45
-
46
 
47
  load_dotenv()
48
 
@@ -333,81 +333,8 @@ for task in tasks:
333
 
334
 
335
  # -------------------------------
336
- # Step 1: Define the planner function
337
- # -------------------------------
338
- def planner(question: str):
339
- """Break down the question into smaller tasks"""
340
- if "how many" in question and "albums" in question:
341
- return ["Retrieve album list", "Filter by date", "Count albums"]
342
- elif "who" in question:
343
- return ["Retrieve biography", "Find related works"]
344
- return ["Default task"]
345
-
346
- # -------------------------------
347
- # Step 2: Task Classifier (decides the best tool to use)
348
- # -------------------------------
349
- def task_classifier(question: str):
350
- """Classify the question to select the best tool"""
351
- if "calculate" in question or any(op in question for op in ["+", "-", "*", "/"]):
352
- return "math"
353
- elif "album" in question or "music" in question:
354
- return "wiki_search"
355
- elif "file" in question or "attachment" in question:
356
- return "file_analysis"
357
- return "default_tool"
358
-
359
- # -------------------------------
360
- # Step 3: Decide Task Function
361
- # -------------------------------
362
- def decide_task(state):
363
- """Logic to decide what to do based on prior actions and results"""
364
- if "no relevant documents" in state.get("last_response", ""):
365
- return "web_search"
366
- if "not found" in state.get("last_response", "").lower():
367
- return "wiki_search"
368
- return "final_answer"
369
-
370
- # -------------------------------
371
- # Step 4: Node Skipper (Skip unnecessary nodes)
372
  # -------------------------------
373
- def node_skipper(state):
374
- """Skip unnecessary nodes based on context"""
375
- if "just generate" in state.get("question", "").lower():
376
- return "answer_generation" # Skip all tools and just generate the answer.
377
- return None # Continue to the next tool or node
378
-
379
-
380
-
381
- # -------------------------------
382
- # Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
383
- # -------------------------------
384
- # Initialize HuggingFace Embedding model
385
- #embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
386
- #embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
387
-
388
-
389
-
390
- class BERTEmbeddings(Embeddings):
391
- def __init__(self, model_name='bert-base-uncased'):
392
- self.tokenizer = BertTokenizer.from_pretrained(model_name)
393
- self.model = BertModel.from_pretrained(model_name)
394
- self.model.eval() # Set to evaluation mode
395
-
396
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
397
- return self._embed(texts)
398
-
399
- def embed_query(self, text: str) -> List[float]:
400
- return self._embed([text])[0]
401
-
402
- def _embed(self, texts: List[str]) -> List[List[float]]:
403
- inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
404
- with torch.no_grad():
405
- outputs = self.model(**inputs)
406
- embeddings = outputs.last_hidden_state.mean(dim=1) # Mean pooling
407
- return embeddings.cpu().numpy().tolist()
408
-
409
- # Example usage of BERTEmbedding with LangChain
410
-
411
 
412
  # -----------------------------
413
  # 1. Define Custom BERT Embedding Model
@@ -517,6 +444,7 @@ agent = initialize_agent(
517
  # -------------------------------
518
  # Step 8: Use the Planner, Classifier, and Decision Logic
519
  # -------------------------------
 
520
  def process_question(question):
521
  # Step 1: Planner generates the task sequence
522
  tasks = planner(question)
@@ -537,18 +465,31 @@ def process_question(question):
537
  print(f"Skipping to {skip}")
538
  return skip # Or move directly to generating answer
539
 
540
- # Execute the task via the agent
541
- response = agent.run(question)
542
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
543
 
544
- # -------------------------------
545
- # Step 9: Run the Planner-Agent Workflow
546
- # -------------------------------
547
  question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
548
  response = process_question(question)
549
  print("Final Response:", response)
550
 
551
 
 
552
  question_retriever_tool = create_retriever_tool(
553
  retriever=retriever,
554
  name="Question_Search",
@@ -561,12 +502,12 @@ def retriever(state: MessagesState):
561
  """Retriever node using similarity scores for filtering"""
562
  query = state["messages"][0].content
563
  results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
564
-
565
  # Dynamically adjust threshold based on query complexity
566
  threshold = 0.75 if "who" in query else 0.8
567
-
568
  filtered = [doc for doc, score in results if score < threshold]
569
-
 
570
  if not filtered:
571
  example_msg = HumanMessage(content="No relevant documents found.")
572
  else:
@@ -605,9 +546,12 @@ def get_llm(provider: str, config: dict):
605
 
606
 
607
  def generate_final_answer(state, tools_results):
608
- # Combine results from all tools
609
- # For example, if both the calculator and the Wikipedia tool were used:
610
- final_answer = f"Answer: {tools_results['wiki_search']}\nAdditional Info: {tools_results['calculator']}"
 
 
 
611
  return final_answer
612
 
613
 
 
42
  from transformers import BertTokenizer, BertModel
43
  import torch
44
  import torch.nn.functional as F
45
+ import time
46
 
47
  load_dotenv()
48
 
 
333
 
334
 
335
  # -------------------------------
336
+ # Step 4: Set up BERT Embeddings and FAISS VectorStore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  # -------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  # -----------------------------
340
  # 1. Define Custom BERT Embedding Model
 
444
  # -------------------------------
445
  # Step 8: Use the Planner, Classifier, and Decision Logic
446
  # -------------------------------
447
+
448
  def process_question(question):
449
  # Step 1: Planner generates the task sequence
450
  tasks = planner(question)
 
465
  print(f"Skipping to {skip}")
466
  return skip # Or move directly to generating answer
467
 
468
+ # Step 5: Execute task (with error handling)
469
+ try:
470
+ if task_type == "wiki_search":
471
+ response = wiki_search_tool(question)
472
+ elif task_type == "math":
473
+ response = calculator_tool(question)
474
+ else:
475
+ response = "Default answer logic"
476
+
477
+ # Step 6: Final response formatting
478
+ final_response = generate_final_answer(state, {'wiki_search': response})
479
+ return final_response
480
+
481
+ except Exception as e:
482
+ print(f"Error executing task: {e}")
483
+ return "Sorry, I encountered an error processing your request."
484
 
485
+
486
+ # Run the process
 
487
  question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
488
  response = process_question(question)
489
  print("Final Response:", response)
490
 
491
 
492
+
493
  question_retriever_tool = create_retriever_tool(
494
  retriever=retriever,
495
  name="Question_Search",
 
502
  """Retriever node using similarity scores for filtering"""
503
  query = state["messages"][0].content
504
  results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
505
+
506
  # Dynamically adjust threshold based on query complexity
507
  threshold = 0.75 if "who" in query else 0.8
 
508
  filtered = [doc for doc, score in results if score < threshold]
509
+
510
+ # Provide a default message if no documents found
511
  if not filtered:
512
  example_msg = HumanMessage(content="No relevant documents found.")
513
  else:
 
546
 
547
 
548
  def generate_final_answer(state, tools_results):
549
+ final_answer = ""
550
+
551
+ # Concatenate results from each tool (wiki_search, calculator, etc.)
552
+ for tool_name, result in tools_results.items():
553
+ final_answer += f"{tool_name} result: {result}\n"
554
+
555
  return final_answer
556
 
557