Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -44,7 +44,7 @@ from transformers import BertTokenizer, BertModel
|
|
44 |
import torch
|
45 |
import torch.nn.functional as F
|
46 |
from langchain.agents import initialize_agent, AgentType
|
47 |
-
from
|
48 |
from langchain_community.tools import Tool
|
49 |
|
50 |
import time
|
@@ -104,17 +104,12 @@ def modulus(a: int, b: int) -> int:
|
|
104 |
|
105 |
|
106 |
@tool
|
107 |
-
def
|
108 |
-
"""
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
"""
|
114 |
-
a = data.get("a")
|
115 |
-
b = data.get("b")
|
116 |
-
operation = data.get("operation", "").lower()
|
117 |
-
|
118 |
if operation == "add":
|
119 |
return a + b
|
120 |
elif operation == "subtract":
|
@@ -123,12 +118,12 @@ def calculator(data: dict) -> float:
|
|
123 |
return a * b
|
124 |
elif operation == "divide":
|
125 |
if b == 0:
|
126 |
-
|
127 |
return a / b
|
128 |
elif operation == "modulus":
|
129 |
return a % b
|
130 |
else:
|
131 |
-
|
132 |
|
133 |
@tool
|
134 |
def wiki_search(query: str) -> str:
|
@@ -419,33 +414,19 @@ docs = [
|
|
419 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
420 |
vector_store.save_local("faiss_index")
|
421 |
|
422 |
-
# -----------------------------
|
423 |
-
# 5. Query & Filter Results (optional preview)
|
424 |
-
# -----------------------------
|
425 |
-
query = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
426 |
-
results = vector_store.similarity_search_with_score(query, k=5)
|
427 |
-
threshold = 0.75
|
428 |
-
filtered = [doc for doc, score in results if score < threshold]
|
429 |
-
|
430 |
-
|
431 |
-
print("\n📊 Retrieved Documents with Similarity Scores:")
|
432 |
-
filtered = []
|
433 |
-
for doc, score in results:
|
434 |
-
print(f"🔢 Score: {score:.4f}")
|
435 |
-
print(f"📄 Content: {doc.page_content}")
|
436 |
-
if score < threshold:
|
437 |
-
filtered.append(doc)
|
438 |
-
print("✅ Accepted")
|
439 |
-
else:
|
440 |
-
print("❌ Rejected")
|
441 |
-
print("-" * 80)
|
442 |
-
|
443 |
|
444 |
# -----------------------------
|
445 |
# 6. Create LangChain Retriever Tool
|
446 |
# -----------------------------
|
447 |
retriever = vector_store.as_retriever()
|
448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
# -------------------------------
|
450 |
# Step 6: Create LangChain Tools
|
451 |
# -------------------------------
|
@@ -463,9 +444,6 @@ wikiq_tool = wikidata_query
|
|
463 |
# -------------------------------
|
464 |
# Step 7: Create the Planner-Agent Logic
|
465 |
# -------------------------------
|
466 |
-
# Define the agent tool set
|
467 |
-
from langchain.chat_models import ChatOpenAI
|
468 |
-
from langchain.agents import initialize_agent, AgentType
|
469 |
|
470 |
# Define the tools (as you've already done)
|
471 |
tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
|
@@ -510,14 +488,14 @@ def process_question(question):
|
|
510 |
# Step 5: Execute task (with error handling)
|
511 |
try:
|
512 |
if task_type == "wiki_search":
|
513 |
-
response =
|
514 |
elif task_type == "math":
|
515 |
-
response =
|
516 |
else:
|
517 |
response = "Default answer logic"
|
518 |
|
519 |
# Step 6: Final response formatting
|
520 |
-
final_response =
|
521 |
return final_response
|
522 |
|
523 |
except Exception as e:
|
@@ -527,16 +505,11 @@ def process_question(question):
|
|
527 |
|
528 |
# Run the process
|
529 |
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
530 |
-
response =
|
531 |
print("Final Response:", response)
|
532 |
|
533 |
|
534 |
|
535 |
-
question_retriever_tool = create_retriever_tool(
|
536 |
-
retriever=retriever,
|
537 |
-
name="Question_Search",
|
538 |
-
description="A tool to retrieve documents related to a user's question."
|
539 |
-
)
|
540 |
|
541 |
|
542 |
|
@@ -562,39 +535,116 @@ def retriever(state: MessagesState):
|
|
562 |
|
563 |
|
564 |
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
subtract,
|
569 |
-
divide,
|
570 |
-
modulus,
|
571 |
-
wiki_search,
|
572 |
-
web_search,
|
573 |
-
arvix_search,
|
574 |
-
]
|
575 |
-
|
576 |
-
|
577 |
def get_llm(provider: str, config: dict):
|
578 |
if provider == "google":
|
|
|
579 |
return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
|
|
|
580 |
elif provider == "groq":
|
|
|
581 |
return ChatGroq(model=config["model"], temperature=config["temperature"])
|
|
|
582 |
elif provider == "huggingface":
|
|
|
|
|
583 |
return ChatHuggingFace(
|
584 |
llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
|
585 |
)
|
|
|
586 |
else:
|
587 |
raise ValueError(f"Invalid provider: {provider}")
|
588 |
|
589 |
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
598 |
|
599 |
|
600 |
|
|
|
44 |
import torch
|
45 |
import torch.nn.functional as F
|
46 |
from langchain.agents import initialize_agent, AgentType
|
47 |
+
from langchain_community.chat_models import ChatOpenAI
|
48 |
from langchain_community.tools import Tool
|
49 |
|
50 |
import time
|
|
|
104 |
|
105 |
|
106 |
@tool
|
107 |
+
def calculator_tool(inputs: dict):
|
108 |
+
"""Perform mathematical operations based on the operation provided."""
|
109 |
+
a = inputs.get("a")
|
110 |
+
b = inputs.get("b")
|
111 |
+
operation = inputs.get("operation")
|
112 |
+
|
|
|
|
|
|
|
|
|
|
|
113 |
if operation == "add":
|
114 |
return a + b
|
115 |
elif operation == "subtract":
|
|
|
118 |
return a * b
|
119 |
elif operation == "divide":
|
120 |
if b == 0:
|
121 |
+
return "Error: Division by zero"
|
122 |
return a / b
|
123 |
elif operation == "modulus":
|
124 |
return a % b
|
125 |
else:
|
126 |
+
return "Unknown operation"
|
127 |
|
128 |
@tool
|
129 |
def wiki_search(query: str) -> str:
|
|
|
414 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
415 |
vector_store.save_local("faiss_index")
|
416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
# -----------------------------
|
419 |
# 6. Create LangChain Retriever Tool
|
420 |
# -----------------------------
|
421 |
retriever = vector_store.as_retriever()
|
422 |
|
423 |
+
question_retriever_tool = create_retriever_tool(
|
424 |
+
retriever=retriever,
|
425 |
+
name="Question_Search",
|
426 |
+
description="A tool to retrieve documents related to a user's question."
|
427 |
+
)
|
428 |
+
|
429 |
+
|
430 |
# -------------------------------
|
431 |
# Step 6: Create LangChain Tools
|
432 |
# -------------------------------
|
|
|
444 |
# -------------------------------
|
445 |
# Step 7: Create the Planner-Agent Logic
|
446 |
# -------------------------------
|
|
|
|
|
|
|
447 |
|
448 |
# Define the tools (as you've already done)
|
449 |
tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
|
|
|
488 |
# Step 5: Execute task (with error handling)
|
489 |
try:
|
490 |
if task_type == "wiki_search":
|
491 |
+
response = wiki_tool(question)
|
492 |
elif task_type == "math":
|
493 |
+
response = calc_tool(question)
|
494 |
else:
|
495 |
response = "Default answer logic"
|
496 |
|
497 |
# Step 6: Final response formatting
|
498 |
+
final_response = final_answer_tool(state, {'wiki_search': response})
|
499 |
return final_response
|
500 |
|
501 |
except Exception as e:
|
|
|
505 |
|
506 |
# Run the process
|
507 |
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
508 |
+
response = agent.run(question)
|
509 |
print("Final Response:", response)
|
510 |
|
511 |
|
512 |
|
|
|
|
|
|
|
|
|
|
|
513 |
|
514 |
|
515 |
|
|
|
535 |
|
536 |
|
537 |
|
538 |
+
# ----------------------------------------------------------------
|
539 |
+
# LLM Loader
|
540 |
+
# ----------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
def get_llm(provider: str, config: dict):
|
542 |
if provider == "google":
|
543 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
544 |
return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
|
545 |
+
|
546 |
elif provider == "groq":
|
547 |
+
from langchain_groq import ChatGroq
|
548 |
return ChatGroq(model=config["model"], temperature=config["temperature"])
|
549 |
+
|
550 |
elif provider == "huggingface":
|
551 |
+
from langchain_huggingface import ChatHuggingFace
|
552 |
+
from langchain_huggingface import HuggingFaceEndpoint
|
553 |
return ChatHuggingFace(
|
554 |
llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
|
555 |
)
|
556 |
+
|
557 |
else:
|
558 |
raise ValueError(f"Invalid provider: {provider}")
|
559 |
|
560 |
|
561 |
+
|
562 |
+
|
563 |
+
# ----------------------------------------------------------------
|
564 |
+
# Planning & Execution Logic
|
565 |
+
# ----------------------------------------------------------------
|
566 |
+
def planner(question: str) -> list:
|
567 |
+
if "calculate" in question or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus"]):
|
568 |
+
return ["math"]
|
569 |
+
elif "wiki" in question or "who is" in question.lower():
|
570 |
+
return ["wiki_search"]
|
571 |
+
else:
|
572 |
+
return ["default"]
|
573 |
+
|
574 |
+
|
575 |
+
def task_classifier(question: str) -> str:
|
576 |
+
if any(op in question.lower() for op in ["add", "subtract", "multiply", "divide", "modulus"]):
|
577 |
+
return "math"
|
578 |
+
elif "who" in question.lower() or "what is" in question.lower():
|
579 |
+
return "wiki_search"
|
580 |
+
else:
|
581 |
+
return "default"
|
582 |
+
|
583 |
+
# Function to extract math operation from the question
|
584 |
+
def extract_math_from_question(question: str):
|
585 |
+
"""Extract numbers and operator from a math question."""
|
586 |
+
match = re.search(r'(\d+)\s*(\+|\-|\*|\/|\%)\s*(\d+)', question)
|
587 |
+
if match:
|
588 |
+
num1 = int(match.group(1))
|
589 |
+
operator = match.group(2)
|
590 |
+
num2 = int(match.group(3))
|
591 |
+
return num1, operator, num2
|
592 |
+
else:
|
593 |
+
return None
|
594 |
+
|
595 |
+
def decide_task(state: dict) -> str:
|
596 |
+
return planner(state["question"])[0]
|
597 |
+
|
598 |
+
|
599 |
+
def node_skipper(state: dict) -> bool:
|
600 |
+
return False
|
601 |
+
|
602 |
+
|
603 |
+
def generate_final_answer(state: dict, task_results: dict) -> str:
|
604 |
+
if "wiki_search" in task_results:
|
605 |
+
return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
|
606 |
+
elif "math" in task_results:
|
607 |
+
return f"🧮 Math Result: {task_results['math']}"
|
608 |
+
else:
|
609 |
+
return "🤖 Unable to generate a specific answer."
|
610 |
+
|
611 |
+
|
612 |
+
# ----------------------------------------------------------------
|
613 |
+
# Process Function (Main Agent Runner)
|
614 |
+
# ----------------------------------------------------------------
|
615 |
+
def process_question(question: str):
|
616 |
+
tasks = planner(question)
|
617 |
+
print(f"Tasks to perform: {tasks}")
|
618 |
+
|
619 |
+
task_type = task_classifier(question)
|
620 |
+
print(f"Task type: {task_type}")
|
621 |
+
|
622 |
+
state = {"question": question, "last_response": "", "messages": [HumanMessage(content=question)]}
|
623 |
+
next_task = decide_task(state)
|
624 |
+
print(f"Next task: {next_task}")
|
625 |
+
|
626 |
+
if node_skipper(state):
|
627 |
+
print(f"Skipping task: {next_task}")
|
628 |
+
return "Task skipped."
|
629 |
+
|
630 |
+
try:
|
631 |
+
if task_type == "wiki_search":
|
632 |
+
response = wiki_tool.run(question)
|
633 |
+
elif task_type == "math":
|
634 |
+
# You should dynamically parse these inputs in real use
|
635 |
+
response = calc_tool.run(question)
|
636 |
+
elif task_type == "retriever":
|
637 |
+
retrieval_result = retriever(state)
|
638 |
+
response = retrieval_result["messages"][-1].content
|
639 |
+
else:
|
640 |
+
response = "Default fallback answer."
|
641 |
+
|
642 |
+
return generate_final_answer(state, {task_type: response})
|
643 |
+
|
644 |
+
except Exception as e:
|
645 |
+
print(f"❌ Error: {e}")
|
646 |
+
return "Sorry, I encountered an error processing your request."
|
647 |
+
|
648 |
|
649 |
|
650 |
|