File size: 6,656 Bytes
f9a7c9b
 
a5c9e62
f9a7c9b
a5c9e62
 
 
f9a7c9b
 
 
 
 
a5c9e62
f9a7c9b
a5c9e62
f9a7c9b
a5c9e62
 
 
f9a7c9b
a5c9e62
 
f9a7c9b
 
 
 
a5c9e62
 
f9a7c9b
a5c9e62
 
 
 
 
 
f9a7c9b
a5c9e62
f9a7c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5c9e62
f9a7c9b
a5c9e62
 
f9a7c9b
 
 
a5c9e62
f9a7c9b
 
 
 
 
 
 
 
a5c9e62
f9a7c9b
a5c9e62
 
 
 
f9a7c9b
 
 
a5c9e62
f9a7c9b
 
 
 
a5c9e62
f9a7c9b
 
 
a5c9e62
 
 
 
 
 
 
 
f9a7c9b
 
a5c9e62
f9a7c9b
 
 
 
 
 
 
a5c9e62
 
 
 
 
 
f9a7c9b
 
 
 
a5c9e62
f9a7c9b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from typing import TypedDict, Annotated
import os
from dotenv import load_dotenv
from langgraph.graph.message import add_messages

# Load environment variables from .env file
# NOTE(review): load_dotenv() runs before the remaining imports — presumably so
# that import-time configuration picks up .env values; confirm before reordering
# these imports to the top of the file.
load_dotenv()
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import tools_condition
from langchain_openai import ChatOpenAI
from tools import agent_tools
from utils import format_gaia_answer, create_execution_plan, log_agent_step

# Initialize OpenAI LLM with GPT-4o (most capable model)
chat = ChatOpenAI(
    model="gpt-4o",
    temperature=0.1,  # near-deterministic output for factual QA
    max_tokens=1024,  # per-turn completion cap
    # NOTE(review): .get() yields None when OPENAI_API_KEY is unset — the
    # client would then fail at call time rather than here; confirm intended.
    api_key=os.environ.get("OPENAI_API_KEY")
)

# Same model with the project's tool schemas bound for function calling.
chat_with_tools = chat.bind_tools(agent_tools)

# System prompt for GAIA evaluation (exact format required by HF)
SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

You have access to tools that can help you:
- Search the web for current information
- Download and process files associated with task IDs
- Analyze images
- Perform calculations
- Process text

IMPORTANT: You must provide a specific answer in the FINAL ANSWER format. Do not say you cannot find information or provide general approaches. Use web search to find the information you need, but limit yourself to 2-3 search attempts maximum. If you cannot find perfect information, make your best determination based on what you found and provide a concrete FINAL ANSWER. Always end with a specific FINAL ANSWER, never with explanations about not finding information."""

# Generate the AgentState
class AgentState(TypedDict):
    # Conversation history; the add_messages reducer appends new messages to
    # the existing list on each state update instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
    # GAIA task identifier; set to "" when no task is associated (see SmartAgent).
    task_id: str

def assistant(state: AgentState):
    """LLM node: run the tool-bound chat model over the conversation.

    Prepends the GAIA system prompt when the history does not already
    contain a SystemMessage, then returns the model's reply wrapped in
    the state-update dict expected by LangGraph.
    """
    history = state["messages"]

    # Inject the system prompt only once per conversation thread.
    if not any(isinstance(m, SystemMessage) for m in history):
        history = [SystemMessage(content=SYSTEM_PROMPT), *history]

    reply = chat_with_tools.invoke(history)
    return {"messages": [reply]}

def create_smart_agent():
    """Build and compile the LangGraph agent graph.

    Graph shape:
        START -> assistant --(tools_condition)--> tools -> assistant
    The compiled graph uses an in-memory checkpointer so each thread_id
    keeps its own conversation state.
    """
    graph = StateGraph(AgentState)

    # Nodes: the LLM step and the tool executor.
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(agent_tools))

    # Edges: enter at the assistant; route to tools only when the LLM
    # requested a tool call, then loop back for the next LLM turn.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    return graph.compile(checkpointer=MemorySaver())

class SmartAgent:
    """High-level intelligent agent class that wraps the LangGraph agent."""

    def __init__(self):
        # Compile the graph once; the same compiled agent serves every call.
        self.agent = create_smart_agent()
        print("πŸ€– Smart Agent initialized with OpenAI GPT-4o and tools")

    def __call__(self, question: str, task_id: str = None) -> tuple:
        """Process a question and return the formatted answer and reasoning trace."""
        try:
            print(f"\n🎯 Processing question: {question[:100]}...")

            # The plan is logged for visibility only; the LLM itself decides
            # which tools to use during the graph run.
            plan = create_execution_plan(question, task_id)
            print(f"πŸ“‹ Execution plan: {plan}")

            # When a task_id exists, tell the model how to fetch its files.
            if task_id:
                prompt = f"Task ID: {task_id}\n\nQuestion: {question}\n\nNote: If this question involves files, use the file_download tool with task_id '{task_id}' to access associated files."
            else:
                prompt = question

            # One checkpoint thread per task so conversations don't mix.
            config = {
                "configurable": {
                    "thread_id": f"task-{task_id}" if task_id else "general"
                },
                "recursion_limit": 15,  # Allow more tool usage for complex searches
            }

            result = self.agent.invoke(
                {
                    "messages": [HumanMessage(content=prompt)],
                    "task_id": task_id or "",
                },
                config=config,
            )

            msgs = result.get("messages") if result else None
            if msgs:
                # The last message holds the model's final reply; the trace
                # stitches every non-empty message body together.
                raw_answer = msgs[-1].content
                trace = [m.content for m in msgs if getattr(m, "content", None)]
                reasoning_text = "\n---\n".join(trace)
            else:
                raw_answer = "No response generated"
                reasoning_text = "No reasoning trace available"

            # Normalize into the exact GAIA submission format.
            formatted_answer = format_gaia_answer(raw_answer)

            print(f"βœ… Raw answer: {raw_answer}")
            print(f"🎯 Formatted answer: {formatted_answer}")

            # Guard against submitting an empty string.
            if not formatted_answer or not formatted_answer.strip():
                print("⚠️ WARNING: Empty formatted answer!")
                formatted_answer = "ERROR: No valid answer extracted"

            return formatted_answer, reasoning_text

        except Exception as e:
            # Top-level boundary: surface the error as the answer instead of
            # crashing the evaluation harness.
            error_msg = f"Error processing question: {str(e)}"
            print(f"❌ {error_msg}")
            return error_msg, f"Error occurred: {str(e)}"

# Module-level singleton: importing this module eagerly builds the agent
# (side effect — constructs the graph and prints an init banner).
smart_agent = SmartAgent()