Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -491,48 +491,274 @@ def get_llm(provider: str, config: dict):
|
|
491 |
# ----------------------------------------------------------------
|
492 |
# Planning & Execution Logic
|
493 |
# ----------------------------------------------------------------
|
494 |
-
def planner(question: str) -> list:
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
|
502 |
|
503 |
def task_classifier(question: str) -> str:
|
504 |
-
|
|
|
|
|
|
|
|
|
|
|
505 |
return "math"
|
506 |
-
|
|
|
|
|
|
|
507 |
return "wiki_search"
|
508 |
-
|
509 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
|
511 |
# Function to extract math operation from the question
|
512 |
def extract_math_from_question(question: str):
|
513 |
-
|
514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
if match:
|
516 |
num1 = int(match.group(1))
|
517 |
operator = match.group(2)
|
518 |
num2 = int(match.group(3))
|
519 |
return num1, operator, num2
|
520 |
-
|
521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
def decide_task(state: dict) -> str:
|
524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
|
527 |
def node_skipper(state: dict) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
return False
|
529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
|
531 |
def generate_final_answer(state: dict, task_results: dict) -> str:
|
|
|
532 |
if "wiki_search" in task_results:
|
533 |
return f"๐ Wiki Summary:\n{task_results['wiki_search']}"
|
534 |
elif "math" in task_results:
|
535 |
return f"๐งฎ Math Result: {task_results['math']}"
|
|
|
|
|
536 |
else:
|
537 |
return "๐ค Unable to generate a specific answer."
|
538 |
|
@@ -541,28 +767,27 @@ def answer_question(question: str) -> str:
|
|
541 |
"""Process a single question and return the answer."""
|
542 |
print(f"Processing question: {question[:50]}...") # Debugging: show first 50 chars
|
543 |
|
544 |
-
# Wrap the question in a HumanMessage from langchain_core
|
545 |
messages = [HumanMessage(content=question)]
|
546 |
-
|
547 |
|
548 |
# Extract the answer from the response
|
549 |
-
answer =
|
550 |
return answer[14:] # Assuming 'answer[14:]' is correct based on your example
|
551 |
|
552 |
|
553 |
def process_all_tasks(tasks: list):
|
554 |
"""Process a list of tasks."""
|
555 |
results = {}
|
556 |
-
|
557 |
for task in tasks:
|
558 |
-
# Ensure task has a question and process it
|
559 |
question = task.get("question", "").strip()
|
560 |
if not question:
|
561 |
print(f"Skipping task with missing or empty 'question': {task}")
|
562 |
continue
|
563 |
-
|
564 |
print(f"\n๐ข Processing Task: {task['task_id']} - Question: {question}")
|
565 |
-
|
566 |
# Call the existing process_question logic
|
567 |
response = process_question(question)
|
568 |
|
@@ -573,41 +798,9 @@ def process_all_tasks(tasks: list):
|
|
573 |
|
574 |
|
575 |
|
576 |
-
def process_question(question: str):
|
577 |
-
tasks = planner(question)
|
578 |
-
print(f"Tasks to perform: {tasks}")
|
579 |
-
|
580 |
-
task_type = task_classifier(question)
|
581 |
-
print(f"Task type: {task_type}")
|
582 |
-
|
583 |
-
state = {"question": question, "last_response": "", "messages": [HumanMessage(content=question)]}
|
584 |
-
next_task = decide_task(state)
|
585 |
-
print(f"Next task: {next_task}")
|
586 |
-
|
587 |
-
if node_skipper(state):
|
588 |
-
print(f"Skipping task: {next_task}")
|
589 |
-
return "Task skipped."
|
590 |
-
|
591 |
-
try:
|
592 |
-
if task_type == "wiki_search":
|
593 |
-
response = wiki_tool.run(question)
|
594 |
-
elif task_type == "math":
|
595 |
-
# You should dynamically parse these inputs in real use
|
596 |
-
response = calc_tool.run(question)
|
597 |
-
elif task_type == "retriever":
|
598 |
-
retrieval_result = retriever(state)
|
599 |
-
response = retrieval_result["messages"][-1].content
|
600 |
-
else:
|
601 |
-
response = "Default fallback answer."
|
602 |
-
|
603 |
-
return generate_final_answer(state, {task_type: response})
|
604 |
-
|
605 |
-
except Exception as e:
|
606 |
-
print(f"โ Error: {e}")
|
607 |
-
return "Sorry, I encountered an error processing your request."
|
608 |
-
|
609 |
|
610 |
|
|
|
611 |
|
612 |
# Build graph function
|
613 |
provider = "huggingface"
|
|
|
491 |
# ----------------------------------------------------------------
|
492 |
# Planning & Execution Logic
|
493 |
# ----------------------------------------------------------------
|
494 |
+
def planner(question: str, tools: "list | None" = None) -> list:
    """Select the tools whose description matches the intent of *question*.

    The question is scanned for intent keywords; every tool whose
    ``description`` contains the name of a triggered intent (for example
    ``"math"`` or ``"wiki_search"``) is selected, in the order given.

    Args:
        question: The user question; matching is case-insensitive.
        tools: Candidate tool objects exposing optional ``name`` and
            ``description`` attributes. A mapping of name -> tool is also
            accepted (its values are used). ``None`` or an empty collection
            yields an empty result instead of raising.

    Returns:
        The matching tools. If nothing matches, tools whose name contains
        ``"default"`` or whose description contains ``"qa"``; failing that,
        a one-element list holding the first tool. Empty when no tools
        were supplied.
    """
    if isinstance(tools, dict):
        # The module-level registry is a dict; iterate its tool objects,
        # not its string keys.
        tools = list(tools.values())
    if not tools:
        return []

    question = question.lower().strip()

    # General keywords for the supported intents (tool names are not hardcoded).
    intent_keywords = {
        "math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
        "wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
        "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
        "arxiv": ["arxiv", "research paper", "scientific paper", "preprint"],
        "youtube": ["youtube", "watch", "play video", "show me a video"],
        "video_analysis": ["analyze video", "summarize video", "video content"],
        "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
        "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
        "general_qa": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"],
    }

    def _text(tool, attribute: str) -> str:
        # Tools may lack the attribute or carry None; treat both as empty.
        return (getattr(tool, attribute, "") or "").lower()

    # The triggered intents depend only on the question, so compute them once
    # instead of re-scanning the keywords for every tool.
    active_intents = [
        intent
        for intent, keywords in intent_keywords.items()
        if any(keyword in question for keyword in keywords)
    ]

    matched_tools = [
        tool
        for tool in tools
        if any(intent in _text(tool, "description") for intent in active_intents)
    ]

    # Nothing matched: fall back to general-purpose tools.
    if not matched_tools:
        matched_tools = [
            tool
            for tool in tools
            if "default" in _text(tool, "name") or "qa" in _text(tool, "description")
        ]

    # Last resort: the first tool.
    return matched_tools or [tools[0]]
|
528 |
+
|
529 |
|
530 |
|
531 |
def task_classifier(question: str) -> str:
    """Classify *question* into a task type by phrase matching.

    Intents are tested in a fixed order and the first hit wins, so the more
    specific math phrases (e.g. "what is the result of") take precedence
    over the generic wiki phrases (e.g. "what is").

    Phrases are matched on word boundaries, optionally followed by a simple
    inflection (``s``, ``d``, ``ed``, ``ing``). This keeps "divided" and
    "adding" classified as math while no longer misfiring on words that
    merely contain a keyword, such as "address" (add) or "sometimes" (times).

    Args:
        question: The user question; matching is case-insensitive.

    Returns:
        One of "math", "wiki_search", "web_search", "arxiv", "youtube",
        "video_analysis", "data_analysis", "wikidata_query" or "default".
    """
    import re  # local import: this block did not previously depend on re

    text = question.lower().strip()

    # Order matters: earlier entries win when several intents match.
    intent_phrases = (
        ("math", (
            "calculate", "how much is", "what is the result of", "evaluate", "solve",
            "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times",
        )),
        ("wiki_search", (
            "who is", "what is", "define", "explain", "tell me about", "give me an overview of",
        )),
        ("web_search", (
            "search", "find", "look up", "google", "get the latest", "current news", "trending",
        )),
        ("arxiv", (
            "arxiv", "latest research", "scientific paper", "research paper", "preprint",
        )),
        ("youtube", (
            "youtube", "watch", "play the video", "show me a video",
        )),
        ("video_analysis", (
            "analyze video", "summarize video", "what happens in the video", "video content",
        )),
        ("data_analysis", (
            "analyze", "visualize", "plot", "graph", "inspect data", "explore dataset",
        )),
        ("wikidata_query", (
            "sparql", "wikidata", "query wikidata", "run sparql", "wikidata query",
        )),
    )

    for intent, phrases in intent_phrases:
        alternatives = "|".join(re.escape(phrase) for phrase in phrases)
        pattern = r"\b(?:" + alternatives + r")(?:s|d|ed|ing)?\b"
        if re.search(pattern, text):
            return intent

    return "default"
|
576 |
+
|
577 |
|
578 |
# Function to extract math operation from the question
|
579 |
def extract_math_from_question(question: str):
    """Extract a simple binary integer operation from *question*.

    Operation words ("plus", "divided by", ...) are translated to their
    symbols, then the first ``<int> <op> <int>`` expression is located.

    Args:
        question: Free text such as "What is 10 divided by 2?".

    Returns:
        A ``(num1, operator, num2)`` tuple where the numbers are ints and
        the operator is one of ``+ - * / %``, or ``None`` when no such
        expression is present.
    """
    question = question.lower()

    # Map word-based operations to symbols.
    ops = {
        "add": "+", "plus": "+",
        "subtract": "-", "minus": "-",
        "multiply": "*", "multiplied by": "*", "times": "*",
        "divide": "/", "divided by": "/",
        "modulus": "%", "mod": "%",
    }

    # Substitute in a single pass, longest phrase first and on word
    # boundaries. Sequential str.replace turned "divided by" into "/d by"
    # (because "divide" was replaced first), so such questions never parsed.
    words = sorted(ops, key=len, reverse=True)
    word_pattern = r"\b(" + "|".join(re.escape(word) for word in words) + r")\b"
    question = re.sub(word_pattern, lambda found: ops[found.group(1)], question)

    # Match expressions like "4 + 5".
    match = re.search(r"(\d+)\s*([+\-*/%])\s*(\d+)", question)
    if match is None:
        return None

    num1 = int(match.group(1))
    operator = match.group(2)
    num2 = int(match.group(3))
    return num1, operator, num2
|
601 |
+
|
602 |
+
|
603 |
+
# Example tool set (adjust these to match your actual tool names)
|
604 |
+
# Registry mapping each task type to the tool object that handles it.
# NOTE(review): calc_tool, wiki_tool, retriever_tool and default_tool are not
# defined in the visible part of this file -- confirm they exist before this
# statement runs.
tools = dict(
    math=calc_tool,
    wiki_search=wiki_tool,
    retriever=retriever_tool,
    default=default_tool,  # fallback tool
)

# Task types from highest to lowest priority; each task is served by the
# registry entry of the same name, with "default" as the final fallback.
priority_order = [
    {"task": task_name, "tool": task_name}
    for task_name in ("math", "wiki_search", "retriever", "default")
]
|
618 |
|
619 |
def decide_task(state: dict) -> tuple:
    """Decide which task to perform, and with which tool, for the current state.

    The planner is asked which of the registered tools fit the question;
    the matching tools are translated back to their task names in the
    ``tools`` registry, and ``prioritize_tasks`` breaks ties when several
    remain.

    Args:
        state: Conversation state; only ``state["question"]`` is read.

    Returns:
        A ``(task_name, tool)`` pair. Always a pair, because the caller
        unpacks two values; falls back to ``("default", <default tool>)``
        when the planner selects nothing that is registered.
    """
    question = state.get("question", "")

    # planner() matches against tool objects and requires them as its second
    # argument; calling it with the question alone raised TypeError.
    matched_tools = planner(question, list(tools.values()))
    print(f"Available tasks: {matched_tools}")  # Debugging: show all possible tasks

    # Translate the selected tool objects back into registry task names.
    candidates = [
        task_name
        for task_name, tool in tools.items()
        if any(tool is matched for matched in matched_tools)
    ]

    if not candidates:
        print("โ No valid tasks were returned from the planner.")
        return "default", tools.get("default")

    if len(candidates) > 1:
        print("โ ๏ธ Multiple tasks found. Deciding based on priority.")
        decision = prioritize_tasks(candidates)
    else:
        only_task = candidates[0]
        decision = (only_task, tools[only_task])

    print(f"Decided on task: {decision[0]}")  # Debugging: show the final task
    return decision
|
640 |
+
|
641 |
|
642 |
+
def prioritize_tasks(tasks: list) -> tuple:
    """Pick the highest-priority task from *tasks* together with its tool.

    ``priority_order`` is walked from highest to lowest priority; the first
    task whose label contains a priority's task type wins.

    Args:
        tasks: Candidate tasks. Each is either a task-name string or an
            object; objects are labelled by their ``name`` attribute
            (falling back to ``str(obj)``), so tool objects no longer
            raise TypeError on the containment check.

    Returns:
        A ``(task, tool)`` pair. When nothing in ``priority_order`` matches,
        the first task paired with the default tool; when *tasks* is empty,
        ``("default", <default tool>)``.
    """
    fallback_tool = tools.get("default")

    if not tasks:
        # Nothing to choose from; previously this hit IndexError on tasks[0].
        return "default", fallback_tool

    def _label(task) -> str:
        # Strings are matched as-is; other objects by name.
        return task if isinstance(task, str) else str(getattr(task, "name", task))

    for priority in priority_order:
        for task in tasks:
            if priority["task"] in _label(task):
                print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}")
                # Use the registered tool for this priority, or the default
                # tool when the registry has no such entry.
                return task, tools.get(priority["tool"], fallback_tool)

    # No priority matched: first task with the default tool.
    return tasks[0], fallback_tool
|
656 |
+
|
657 |
+
|
658 |
+
def process_question(question: str):
    """Route *question* to the appropriate tool and format its answer.

    Steps: decide the task and tool, give ``node_skipper`` a chance to veto
    the work, run the tool, and format the result with
    ``generate_final_answer``.

    Args:
        question: The user question.

    Returns:
        The formatted answer, ``"Task skipped."`` when ``node_skipper``
        vetoes the question, or an apology string describing the error if
        planning or the tool raises.
    """
    state = {"question": question}

    # Planning happens inside the try block so that a planner or registry
    # failure becomes an error message instead of an unhandled exception.
    # The former standalone planner(question) call was dropped: it only fed
    # a debug print, duplicated decide_task's own planning, and was missing
    # planner's required `tools` argument.
    try:
        decision = decide_task(state)
        if isinstance(decision, tuple):
            task_type, tool = decision
        elif isinstance(decision, str):
            # A bare task name: look its tool up in the registry.
            task_type = decision
            tool = tools.get(task_type, tools.get("default"))
        else:
            # A bare tool object: run it under the default task type.
            task_type, tool = "default", decision
        print(f"Next task: {task_type} with tool: {tool}")

        if node_skipper(state):
            print(f"Skipping task: {task_type}")
            return "Task skipped."

        # Every task type runs its tool the same way, so no per-type branch
        # is needed.
        response = tool.run(question)
        return generate_final_answer(state, {task_type: response})

    except Exception as e:
        print(f"โ Error: {e}")
        return f"Sorry, I encountered an error: {str(e)}"
|
687 |
+
|
688 |
+
|
689 |
+
|
690 |
+
|
691 |
+
# To store previously asked questions and timestamps (simulating state persistence)
|
692 |
+
# Maps each stripped question to the time.time() value at which it was last
# recorded by update_recent_questions().
recent_questions = {}

# A question recorded less than this many seconds ago counts as a repeat.
_REPEAT_WINDOW_SECONDS = 300


def node_skipper(state: dict) -> bool:
    """Decide whether the task for the current state should be skipped.

    A task is skipped when any of the following holds, checked in order:

    1. The question is missing, not a string, or blank.
    2. The question was recorded in ``recent_questions`` within the last
       five minutes (expired entries are removed when encountered).
    3. The question contains one of the "irrelevant" keywords.
    4. ``state["last_response"]`` is already set to a truthy value.
    5. The question mentions "math" and contains a known trivial expression.
    6. The local hour is between 22:00 and 06:00.

    Args:
        state: Conversation state; reads ``question`` and ``last_response``.

    Returns:
        True when the task should be skipped, False otherwise.
    """
    raw_question = state.get("question")
    # A None or non-string question used to crash on .strip(); treat it as empty.
    question = raw_question.strip() if isinstance(raw_question, str) else ""

    if not question:
        print("โ Skipping: Empty or invalid question.")
        return True

    # 1. Skip if the question was asked within the repeat window.
    last_asked_time = recent_questions.get(question)
    if last_asked_time is not None:
        time_since_last_ask = time.time() - last_asked_time
        if time_since_last_ask < _REPEAT_WINDOW_SECONDS:
            print(f"โ Skipping: The question has been asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
            return True
        # The entry has expired; drop it so the dict does not grow forever.
        del recent_questions[question]

    lowered = question.lower()

    # 2. Skip if the question is irrelevant or not meaningful enough.
    irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
    if any(keyword in lowered for keyword in irrelevant_keywords):
        print("โ Skipping: Irrelevant or nonsense question.")
        return True

    # 3. Skip if a response has already been produced for this state.
    if state.get("last_response"):
        print("โ Skipping: Task has already been processed recently.")
        return True

    # 4. Skip trivial math questions (only when the question mentions "math").
    if "math" in lowered:
        trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
        if any(trivial_question in question for trivial_question in trivial_math):
            print(f"โ Skipping trivial math question: {question}")
            return True

    # 5. Skip during the night, by the server's local clock.
    # NOTE(review): this rule came from example code; it makes every question
    # return "skipped" for eight hours a day -- confirm it is intended.
    current_hour = time.localtime().tm_hour
    if current_hour >= 22 or current_hour < 6:
        print("โ Skipping: It's night time, not processing tasks.")
        return True

    # None of the conditions matched: do not skip.
    return False


def update_recent_questions(question: str):
    """Record *question* as asked now, for node_skipper's repeat check.

    The key is stripped so that it matches the stripped question that
    node_skipper looks up.
    """
    recent_questions[question.strip()] = time.time()
|
751 |
+
|
752 |
+
|
753 |
|
754 |
def generate_final_answer(state: dict, task_results: dict) -> str:
    """Format the final answer from the results of the executed task.

    Args:
        state: Conversation state (currently unused; kept for callers).
        task_results: Mapping of task type to that task's result.

    Returns:
        The result under the first of "wiki_search", "math" or "retriever"
        that is present, with its label prefixed. Otherwise the first
        non-empty result of any other task type, as plain text, so that
        output from e.g. the "default" tool is no longer discarded. If there
        is no usable result at all, a fixed "unable to answer" message.
    """
    # Labelled task types, in order of precedence.
    labelled_formats = (
        ("wiki_search", "๐ Wiki Summary:\n{}"),
        ("math", "๐งฎ Math Result: {}"),
        ("retriever", "๐ Retrieved Info: {}"),
    )
    for task_type, template in labelled_formats:
        if task_type in task_results:
            return template.format(task_results[task_type])

    # Any other task type still produced a result; return it unlabelled.
    for result in task_results.values():
        if result is not None and str(result).strip():
            return str(result)

    return "๐ค Unable to generate a specific answer."
|
764 |
|
|
|
767 |
"""Process a single question and return the answer."""
|
768 |
print(f"Processing question: {question[:50]}...") # Debugging: show first 50 chars
|
769 |
|
770 |
+
# Wrap the question in a HumanMessage from langchain_core (assuming langchain is used)
|
771 |
messages = [HumanMessage(content=question)]
|
772 |
+
response = graph.invoke({"messages": messages}) # Assuming `graph` is defined elsewhere
|
773 |
|
774 |
# Extract the answer from the response
|
775 |
+
answer = response['messages'][-1].content
|
776 |
return answer[14:] # Assuming 'answer[14:]' is correct based on your example
|
777 |
|
778 |
|
779 |
def process_all_tasks(tasks: list):
|
780 |
"""Process a list of tasks."""
|
781 |
results = {}
|
782 |
+
|
783 |
for task in tasks:
|
|
|
784 |
question = task.get("question", "").strip()
|
785 |
if not question:
|
786 |
print(f"Skipping task with missing or empty 'question': {task}")
|
787 |
continue
|
788 |
+
|
789 |
print(f"\n๐ข Processing Task: {task['task_id']} - Question: {question}")
|
790 |
+
|
791 |
# Call the existing process_question logic
|
792 |
response = process_question(question)
|
793 |
|
|
|
798 |
|
799 |
|
800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
801 |
|
802 |
|
803 |
+
## Langgraph
|
804 |
|
805 |
# Build graph function
|
806 |
provider = "huggingface"
|