wt002 commited on
Commit
65f51b7
·
verified ·
1 Parent(s): 450a49d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +117 -67
agent.py CHANGED
@@ -44,7 +44,7 @@ from transformers import BertTokenizer, BertModel
44
  import torch
45
  import torch.nn.functional as F
46
  from langchain.agents import initialize_agent, AgentType
47
- from langchain_openai import ChatOpenAI
48
  from langchain_community.tools import Tool
49
 
50
  import time
@@ -104,17 +104,12 @@ def modulus(a: int, b: int) -> int:
104
 
105
 
106
  @tool
107
- def calculator(data: dict) -> float:
108
- """
109
- Perform a calculation between two numbers.
110
-
111
- Args:
112
- data: A dictionary with keys 'a', 'b', and 'operation'
113
- """
114
- a = data.get("a")
115
- b = data.get("b")
116
- operation = data.get("operation", "").lower()
117
-
118
  if operation == "add":
119
  return a + b
120
  elif operation == "subtract":
@@ -123,12 +118,12 @@ def calculator(data: dict) -> float:
123
  return a * b
124
  elif operation == "divide":
125
  if b == 0:
126
- raise ValueError("Cannot divide by zero.")
127
  return a / b
128
  elif operation == "modulus":
129
  return a % b
130
  else:
131
- raise ValueError(f"Unsupported operation: {operation}")
132
 
133
  @tool
134
  def wiki_search(query: str) -> str:
@@ -419,33 +414,19 @@ docs = [
419
  vector_store = FAISS.from_documents(docs, embedding_model)
420
  vector_store.save_local("faiss_index")
421
 
422
- # -----------------------------
423
- # 5. Query & Filter Results (optional preview)
424
- # -----------------------------
425
- query = "How many albums did Mercedes Sosa release between 2000 and 2009?"
426
- results = vector_store.similarity_search_with_score(query, k=5)
427
- threshold = 0.75
428
- filtered = [doc for doc, score in results if score < threshold]
429
-
430
-
431
- print("\n📊 Retrieved Documents with Similarity Scores:")
432
- filtered = []
433
- for doc, score in results:
434
- print(f"🔢 Score: {score:.4f}")
435
- print(f"📄 Content: {doc.page_content}")
436
- if score < threshold:
437
- filtered.append(doc)
438
- print("✅ Accepted")
439
- else:
440
- print("❌ Rejected")
441
- print("-" * 80)
442
-
443
 
444
  # -----------------------------
445
  # 6. Create LangChain Retriever Tool
446
  # -----------------------------
447
  retriever = vector_store.as_retriever()
448
 
 
 
 
 
 
 
 
449
  # -------------------------------
450
  # Step 6: Create LangChain Tools
451
  # -------------------------------
@@ -463,9 +444,6 @@ wikiq_tool = wikidata_query
463
  # -------------------------------
464
  # Step 7: Create the Planner-Agent Logic
465
  # -------------------------------
466
- # Define the agent tool set
467
- from langchain.chat_models import ChatOpenAI
468
- from langchain.agents import initialize_agent, AgentType
469
 
470
  # Define the tools (as you've already done)
471
  tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
@@ -510,14 +488,14 @@ def process_question(question):
510
  # Step 5: Execute task (with error handling)
511
  try:
512
  if task_type == "wiki_search":
513
- response = wiki_search_tool(question)
514
  elif task_type == "math":
515
- response = calculator_tool(question)
516
  else:
517
  response = "Default answer logic"
518
 
519
  # Step 6: Final response formatting
520
- final_response = generate_final_answer(state, {'wiki_search': response})
521
  return final_response
522
 
523
  except Exception as e:
@@ -527,16 +505,11 @@ def process_question(question):
527
 
528
  # Run the process
529
  question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
530
- response = process_question(question)
531
  print("Final Response:", response)
532
 
533
 
534
 
535
- question_retriever_tool = create_retriever_tool(
536
- retriever=retriever,
537
- name="Question_Search",
538
- description="A tool to retrieve documents related to a user's question."
539
- )
540
 
541
 
542
 
@@ -562,39 +535,116 @@ def retriever(state: MessagesState):
562
 
563
 
564
 
565
- tools = [
566
- multiply,
567
- add,
568
- subtract,
569
- divide,
570
- modulus,
571
- wiki_search,
572
- web_search,
573
- arvix_search,
574
- ]
575
-
576
-
577
  def get_llm(provider: str, config: dict):
578
  if provider == "google":
 
579
  return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
 
580
  elif provider == "groq":
 
581
  return ChatGroq(model=config["model"], temperature=config["temperature"])
 
582
  elif provider == "huggingface":
 
 
583
  return ChatHuggingFace(
584
  llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
585
  )
 
586
  else:
587
  raise ValueError(f"Invalid provider: {provider}")
588
 
589
 
590
- def generate_final_answer(state, tools_results):
591
- final_answer = ""
592
-
593
- # Concatenate results from each tool (wiki_search, calculator, etc.)
594
- for tool_name, result in tools_results.items():
595
- final_answer += f"{tool_name} result: {result}\n"
596
-
597
- return final_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
 
599
 
600
 
 
44
  import torch
45
  import torch.nn.functional as F
46
  from langchain.agents import initialize_agent, AgentType
47
+ from langchain_community.chat_models import ChatOpenAI
48
  from langchain_community.tools import Tool
49
 
50
  import time
 
104
 
105
 
106
  @tool
107
+ def calculator_tool(inputs: dict):
108
+ """Perform mathematical operations based on the operation provided."""
109
+ a = inputs.get("a")
110
+ b = inputs.get("b")
111
+ operation = inputs.get("operation")
112
+
 
 
 
 
 
113
  if operation == "add":
114
  return a + b
115
  elif operation == "subtract":
 
118
  return a * b
119
  elif operation == "divide":
120
  if b == 0:
121
+ return "Error: Division by zero"
122
  return a / b
123
  elif operation == "modulus":
124
  return a % b
125
  else:
126
+ return "Unknown operation"
127
 
128
  @tool
129
  def wiki_search(query: str) -> str:
 
414
  vector_store = FAISS.from_documents(docs, embedding_model)
415
  vector_store.save_local("faiss_index")
416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
  # -----------------------------
419
  # 6. Create LangChain Retriever Tool
420
  # -----------------------------
421
  retriever = vector_store.as_retriever()
422
 
423
+ question_retriever_tool = create_retriever_tool(
424
+ retriever=retriever,
425
+ name="Question_Search",
426
+ description="A tool to retrieve documents related to a user's question."
427
+ )
428
+
429
+
430
  # -------------------------------
431
  # Step 6: Create LangChain Tools
432
  # -------------------------------
 
444
  # -------------------------------
445
  # Step 7: Create the Planner-Agent Logic
446
  # -------------------------------
 
 
 
447
 
448
  # Define the tools (as you've already done)
449
  tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
 
488
  # Step 5: Execute task (with error handling)
489
  try:
490
  if task_type == "wiki_search":
491
+ response = wiki_tool(question)
492
  elif task_type == "math":
493
+ response = calc_tool(question)
494
  else:
495
  response = "Default answer logic"
496
 
497
  # Step 6: Final response formatting
498
+ final_response = final_answer_tool(state, {'wiki_search': response})
499
  return final_response
500
 
501
  except Exception as e:
 
505
 
506
  # Run the process
507
  question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
508
+ response = agent.run(question)
509
  print("Final Response:", response)
510
 
511
 
512
 
 
 
 
 
 
513
 
514
 
515
 
 
535
 
536
 
537
 
538
+ # ----------------------------------------------------------------
539
+ # LLM Loader
540
+ # ----------------------------------------------------------------
 
 
 
 
 
 
 
 
 
541
  def get_llm(provider: str, config: dict):
542
  if provider == "google":
543
+ from langchain_google_genai import ChatGoogleGenerativeAI
544
  return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
545
+
546
  elif provider == "groq":
547
+ from langchain_groq import ChatGroq
548
  return ChatGroq(model=config["model"], temperature=config["temperature"])
549
+
550
  elif provider == "huggingface":
551
+ from langchain_huggingface import ChatHuggingFace
552
+ from langchain_huggingface import HuggingFaceEndpoint
553
  return ChatHuggingFace(
554
  llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
555
  )
556
+
557
  else:
558
  raise ValueError(f"Invalid provider: {provider}")
559
 
560
 
561
+
562
+
563
+ # ----------------------------------------------------------------
564
+ # Planning & Execution Logic
565
+ # ----------------------------------------------------------------
566
+ def planner(question: str) -> list:
567
+ if "calculate" in question or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus"]):
568
+ return ["math"]
569
+ elif "wiki" in question or "who is" in question.lower():
570
+ return ["wiki_search"]
571
+ else:
572
+ return ["default"]
573
+
574
+
575
+ def task_classifier(question: str) -> str:
576
+ if any(op in question.lower() for op in ["add", "subtract", "multiply", "divide", "modulus"]):
577
+ return "math"
578
+ elif "who" in question.lower() or "what is" in question.lower():
579
+ return "wiki_search"
580
+ else:
581
+ return "default"
582
+
583
+ # Function to extract math operation from the question
584
+ def extract_math_from_question(question: str):
585
+ """Extract numbers and operator from a math question."""
586
+ match = re.search(r'(\d+)\s*(\+|\-|\*|\/|\%)\s*(\d+)', question)
587
+ if match:
588
+ num1 = int(match.group(1))
589
+ operator = match.group(2)
590
+ num2 = int(match.group(3))
591
+ return num1, operator, num2
592
+ else:
593
+ return None
594
+
595
+ def decide_task(state: dict) -> str:
596
+ return planner(state["question"])[0]
597
+
598
+
599
+ def node_skipper(state: dict) -> bool:
600
+ return False
601
+
602
+
603
+ def generate_final_answer(state: dict, task_results: dict) -> str:
604
+ if "wiki_search" in task_results:
605
+ return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
606
+ elif "math" in task_results:
607
+ return f"🧮 Math Result: {task_results['math']}"
608
+ else:
609
+ return "🤖 Unable to generate a specific answer."
610
+
611
+
612
+ # ----------------------------------------------------------------
613
+ # Process Function (Main Agent Runner)
614
+ # ----------------------------------------------------------------
615
+ def process_question(question: str):
616
+ tasks = planner(question)
617
+ print(f"Tasks to perform: {tasks}")
618
+
619
+ task_type = task_classifier(question)
620
+ print(f"Task type: {task_type}")
621
+
622
+ state = {"question": question, "last_response": "", "messages": [HumanMessage(content=question)]}
623
+ next_task = decide_task(state)
624
+ print(f"Next task: {next_task}")
625
+
626
+ if node_skipper(state):
627
+ print(f"Skipping task: {next_task}")
628
+ return "Task skipped."
629
+
630
+ try:
631
+ if task_type == "wiki_search":
632
+ response = wiki_tool.run(question)
633
+ elif task_type == "math":
634
+ # You should dynamically parse these inputs in real use
635
+ response = calc_tool.run(question)
636
+ elif task_type == "retriever":
637
+ retrieval_result = retriever(state)
638
+ response = retrieval_result["messages"][-1].content
639
+ else:
640
+ response = "Default fallback answer."
641
+
642
+ return generate_final_answer(state, {task_type: response})
643
+
644
+ except Exception as e:
645
+ print(f"❌ Error: {e}")
646
+ return "Sorry, I encountered an error processing your request."
647
+
648
 
649
 
650