|
from typing import TypedDict, Annotated |
|
import os |
|
from dotenv import load_dotenv |
|
from langgraph.graph.message import add_messages |
|
|
|
|
|
# Load environment variables from a .env file before the chat model client is
# constructed below (it reads OPENAI_API_KEY from the environment; presumably
# the .env file is where that key is supplied — confirm for your deployment).
load_dotenv()
|
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage |
|
from langgraph.prebuilt import ToolNode |
|
from langgraph.graph import START, StateGraph |
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langgraph.prebuilt import tools_condition |
|
from langchain_openai import ChatOpenAI |
|
from tools import agent_tools |
|
from utils import format_gaia_answer, log_agent_step |
|
|
|
|
|
# Chat model shared by every graph run: GPT-4o at a low sampling temperature
# (0.1), with each completion capped at 1024 tokens. The API key is read from
# the process environment; os.getenv returns None when it is unset.
chat = ChatOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4o",
    temperature=0.1,
    max_tokens=1024,
)

# The same model with the agent's tool schemas bound, so its replies may carry
# tool calls for the graph's "tools" node (ToolNode(agent_tools)) to execute.
chat_with_tools = chat.bind_tools(agent_tools)
|
|
|
|
|
# System prompt that `assistant` prepends whenever the conversation contains no
# SystemMessage. The first paragraph fixes the "FINAL ANSWER: ..." output
# template, which the answer formatter presumably relies on.
# NOTE(review): that paragraph appears to be the GAIA benchmark's reference
# prompt, kept verbatim (including its wording) — confirm before rewording.
# The remaining paragraphs list the available tool categories and ask the
# model to cap web searches at 2-3 attempts and always commit to an answer.
SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.



You have access to tools that can help you:

- Search the web for current information

- Download and process files associated with task IDs

- Analyze images

- Perform calculations

- Process text



IMPORTANT: You must provide a specific answer in the FINAL ANSWER format. Do not say you cannot find information or provide general approaches. Use web search to find the information you need, but limit yourself to 2-3 search attempts maximum. If you cannot find perfect information, make your best determination based on what you found and provide a concrete FINAL ANSWER. Always end with a specific FINAL ANSWER, never with explanations about not finding information."""
|
|
|
|
|
class AgentState(TypedDict):
    """State carried through the assistant/tools graph."""

    # Conversation history. The add_messages annotation is LangGraph's reducer:
    # messages returned by a node are merged into this list rather than
    # replacing it.
    messages: Annotated[list[AnyMessage], add_messages]

    # Task identifier set by SmartAgent ("" when no task id was given).
    # The assistant node only reads "messages", not this field.
    task_id: str
|
|
|
def assistant(state: AgentState):
    """Run one model turn over the conversation held in ``state``.

    When the history has no SystemMessage, SYSTEM_PROMPT is placed in front of
    it for this call only; it is not written back into the state, so the check
    repeats on every turn. The tool-bound model's reply is returned as a
    partial state update for the ``messages`` reducer.
    """
    history = state["messages"]

    has_system_prompt = any(isinstance(entry, SystemMessage) for entry in history)
    prompt = history if has_system_prompt else [SystemMessage(content=SYSTEM_PROMPT), *history]

    reply = chat_with_tools.invoke(prompt)
    return {"messages": [reply]}
|
|
|
def create_smart_agent():
    """Compile and return the agent's LangGraph workflow.

    Topology:

        START -> assistant
        assistant --tools_condition--> tools, or end of run
        tools -> assistant

    tools_condition routes to the "tools" node when the assistant's latest
    message requests tool calls and ends the run otherwise. Tool results flow
    back to the assistant for another turn. No checkpointer is attached.
    """
    graph = StateGraph(AgentState)

    # Nodes: the model turn and the prebuilt executor for agent_tools.
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(agent_tools))

    # Edges: enter at the assistant, branch on tool calls, loop back after tools.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    return graph.compile()
|
|
|
class SmartAgent:
    """High-level agent that wraps the compiled LangGraph graph.

    Calling an instance runs one question through the graph and returns a
    ``(formatted_answer, reasoning_trace)`` tuple of strings.
    """

    # LangGraph step budget per question (previously hard-coded). LangGraph
    # aborts the run with an error once it is exceeded; __call__ reports that
    # error as the answer.
    DEFAULT_RECURSION_LIMIT = 15

    def __init__(self, recursion_limit: int = DEFAULT_RECURSION_LIMIT):
        """Compile the agent graph.

        Args:
            recursion_limit: Value passed to LangGraph as ``recursion_limit``
                for every question. Defaults to 15.
        """
        self.agent = create_smart_agent()
        self.recursion_limit = recursion_limit
        print("🤖 Smart Agent initialized with OpenAI GPT-4o and tools")

    def __call__(self, question: str, task_id: str = None) -> tuple[str, str]:
        """Answer ``question`` and return ``(formatted_answer, reasoning_trace)``.

        Args:
            question: The question text.
            task_id: Optional task identifier. When given, it is prepended to
                the prompt and stored in the graph state.

        Returns:
            The answer after ``format_gaia_answer``, or
            ``"ERROR: No valid answer extracted"`` when that is empty. The
            reasoning trace is the non-empty text of every message, separated
            by ``---`` lines. Any exception is returned as
            ``("Error processing question: ...", "Error occurred: ...")``
            instead of propagating.
        """
        try:
            print(f"\n🎯 Processing question: {question[:100]}...")

            result = self.agent.invoke(
                self._build_initial_state(question, task_id),
                config={"recursion_limit": self.recursion_limit},
            )
            raw_answer, reasoning_text = self._extract_answer_and_trace(result)

            formatted_answer = format_gaia_answer(raw_answer)

            print(f"✅ Raw answer: {raw_answer}")
            print(f"🎯 Formatted answer: {formatted_answer}")

            if not formatted_answer or not formatted_answer.strip():
                print("⚠️ WARNING: Empty formatted answer!")
                formatted_answer = "ERROR: No valid answer extracted"

            return formatted_answer, reasoning_text

        except Exception as e:
            # Top-level boundary: report the failure in the returned tuple
            # rather than propagating it to the caller.
            error_msg = f"Error processing question: {e}"
            print(f"❌ {error_msg}")
            return error_msg, f"Error occurred: {e}"

    @staticmethod
    def _build_initial_state(question: str, task_id: str = None) -> dict:
        """Build the graph input: one HumanMessage plus the task id ("" if none)."""
        content = f"Task ID: {task_id}\n\nQuestion: {question}" if task_id else question
        return {
            "messages": [HumanMessage(content=content)],
            "task_id": task_id or "",
        }

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content to plain text.

        LangChain message content is either a string or a list of content
        blocks (strings, or dicts such as ``{"type": "text", "text": ...}``).
        Text parts are joined with newlines; non-text blocks are dropped.
        """
        if not content:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "\n".join(parts)
        return str(content)

    @classmethod
    def _extract_answer_and_trace(cls, result) -> tuple[str, str]:
        """Return ``(raw_answer, reasoning_text)`` from a graph result.

        The answer is the text of the last message. The trace joins the
        non-empty text of every message with ``---`` separator lines.
        Placeholder strings are returned when ``result`` has no messages.
        """
        messages = result.get("messages") if result else None
        if not messages:
            return "No response generated", "No reasoning trace available"

        raw_answer = cls._content_to_text(getattr(messages[-1], "content", ""))
        texts = (cls._content_to_text(getattr(msg, "content", "")) for msg in messages)
        reasoning_text = "\n---\n".join(text for text in texts if text)
        return raw_answer, reasoning_text
|
|
|
# Module-level instance, created at import time: importing this module
# compiles the agent graph and prints the initialization message.
smart_agent = SmartAgent()
|
|