"""GAIA question-answering agent built on LangGraph and Gemini.

The module wires a tool-calling Gemini model into a two-node LangGraph
workflow ("agent" and "tools") and exposes ``run_agent`` /
``process_question`` as entry points.
"""

import json
import logging
import operator
import os
import random
import time
from typing import Annotated, Sequence, TypedDict

from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GAIA_Agent")

# Load environment variables
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")

# Limits enforced inside the graph and by the outer wrapper.
GLOBAL_TIMEOUT_SECONDS = 120  # per-run budget checked by every node
OVERALL_TIMEOUT_SECONDS = 180  # budget checked by run_agent between steps
MAX_STEPS = 8
MAX_CONSECUTIVE_API_ERRORS = 3
MAX_TOOL_TEXT = 2000  # characters of content kept per search hit
MAX_TOOL_RESULT = 1000  # characters of a tool result fed back to the model


# --- Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    return a % b


# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    # The query is passed through unchanged; an earlier revision appended
    # "discography" to every query, which skewed all unrelated lookups.
    try:
        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."
        results = []
        for doc in docs:
            title = doc.metadata.get("title", "Unknown Title")
            content = doc.page_content[:MAX_TOOL_TEXT]  # Limit content length
            results.append(f"Title: {title}\nContent: {content}")
        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"Wikipedia search error: {str(e)}"


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."
        results = []
        for doc in docs:
            title = doc.metadata.get("Title", "Unknown Title")
            authors = doc.metadata.get("Authors", [])
            # Accept either a ready-made string or a list of author names.
            if not isinstance(authors, str):
                authors = ", ".join(authors)
            content = doc.page_content[:MAX_TOOL_TEXT]  # Limit content length
            results.append(
                f"Title: {title}\nAuthors: {authors}\nContent: {content}"
            )
        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"arXiv search error: {str(e)}"


@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        search = DuckDuckGoSearchRun()
        result = search.run(query)
        # Limit content length
        return f"Web search results for '{query}':\n{result[:MAX_TOOL_TEXT]}"
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"Web search error: {str(e)}"


# --- Load system prompt ---
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# --- Tool Setup ---
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]


# --- Message helpers ---
def _content_to_text(content) -> str:
    """Flatten message content into a plain string.

    Content may be a string or a list of parts (strings or dicts carrying a
    "text" entry); non-text parts are ignored.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return "" if content is None else str(content)


def _is_agent_error(message) -> bool:
    """Return True for the internal "AGENT ERROR" sentinel messages."""
    return isinstance(message, AIMessage) and _content_to_text(
        message.content
    ).startswith("AGENT ERROR")


def _extract_tool_calls(message) -> list:
    """Return the tool calls on *message* as dicts with name/args/id.

    Prefers the ``tool_calls`` attribute and falls back to the
    ``additional_kwargs["tool_calls"]`` layout (``function.name`` /
    ``function.arguments``) that this module originally read.
    """
    calls = [
        {
            "name": call["name"],
            "args": call.get("args") or {},
            "id": call.get("id"),
        }
        for call in (getattr(message, "tool_calls", None) or [])
    ]
    if calls:
        return calls

    extra = getattr(message, "additional_kwargs", None) or {}
    for call in extra.get("tool_calls") or []:
        function = call["function"]
        calls.append(
            {
                "name": function["name"],
                "args": function.get("arguments", {}),
                "id": call.get("id"),
            }
        )
    return calls


def _parse_tool_args(raw_args):
    """Decode JSON-encoded arguments; wrap a bare string as a query."""
    if not isinstance(raw_args, str):
        return raw_args
    try:
        return json.loads(raw_args)
    except json.JSONDecodeError:
        return {"query": raw_args}


# --- Graph Builder ---
def build_graph():
    """Build and compile the agent/tools LangGraph workflow.

    Returns:
        The compiled graph. Its state carries the message history plus the
        bookkeeping fields declared in ``AgentState``.
    """
    # Initialize model with Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=0,  # Disable internal retries
        request_timeout=30,  # Keep timeout reasonable
    )

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)
    tools_by_name = {t.name: t for t in tools}

    # 1. Define state structure
    class AgentState(TypedDict):
        messages: Annotated[Sequence, operator.add]
        step_count: int
        start_time: float
        last_action: str
        api_errors: int  # Track consecutive API errors

    # 2. Create graph
    workflow = StateGraph(AgentState)

    # 3. Define node functions
    def _state_update(state, messages, action, api_errors=None):
        """Build a node's return value; the step counter always advances."""
        if api_errors is None:
            api_errors = state.get("api_errors", 0)  # Preserve error count
        return {
            "messages": messages,
            "step_count": state.get("step_count", 0) + 1,
            "start_time": state.get("start_time", time.time()),
            "last_action": action,
            "api_errors": api_errors,
        }

    def _timed_out(state) -> bool:
        """Check the per-run time budget."""
        started = state.get("start_time", time.time())
        return time.time() - started > GLOBAL_TIMEOUT_SECONDS

    def _timeout_update(state):
        minutes = GLOBAL_TIMEOUT_SECONDS // 60
        text = (
            "AGENT ERROR (GLOBAL_TIMEOUT): "
            f"Execution exceeded {minutes}-minute limit"
        )
        return _state_update(state, [AIMessage(content=text)], "timeout")

    def agent_node(state: AgentState):
        """Main agent node with manual retry handling"""
        # Check global timeout
        if _timed_out(state):
            return _timeout_update(state)

        # Check step limit
        if state.get("step_count", 0) >= MAX_STEPS:
            text = (
                "AGENT ERROR (STEP_LIMIT): "
                f"Exceeded maximum step count of {MAX_STEPS}"
            )
            return _state_update(
                state, [AIMessage(content=text)], "step_limit"
            )

        # Check consecutive API errors
        if state.get("api_errors", 0) >= MAX_CONSECUTIVE_API_ERRORS:
            text = "AGENT ERROR (API_LIMIT): Too many consecutive API errors"
            return _state_update(state, [AIMessage(content=text)], "api_limit")

        # Error sentinels exist only for routing; keep them out of the
        # prompt so a retry does not show them to the model as its own turns.
        history = [m for m in state["messages"] if not _is_agent_error(m)]

        try:
            # Add variable delay to avoid rate limiting (2-5 seconds)
            time.sleep(2 + random.uniform(0, 3))
            # Call API without automatic retries
            response = llm_with_tools.invoke(history)
        except Exception as e:
            # Detailed error logging
            error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
            logger.error("%s", error_details)

            error_text = str(e)
            error_type = "UNKNOWN"
            if "429" in error_text or "ResourceExhausted" in error_text:
                error_type = "RESOURCE_EXHAUSTED"
            elif "400" in error_text:
                error_type = "INVALID_REQUEST"
            elif "503" in error_text:
                error_type = "SERVICE_UNAVAILABLE"

            error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
            return _state_update(
                state,
                [AIMessage(content=error_msg)],
                "error",
                state.get("api_errors", 0) + 1,  # Increment error counter
            )

        # Reset error counter on success
        return _state_update(state, [response], "agent", 0)

    def tool_node(state: AgentState):
        """Tool execution node"""
        # Check global timeout
        if _timed_out(state):
            return _timeout_update(state)

        responses = []
        calls = _extract_tool_calls(state["messages"][-1])
        for index, call in enumerate(calls):
            tool_name = call["name"]
            # ToolMessage needs a string id; synthesize one if it is missing.
            call_id = call["id"] or f"{tool_name}-{index}"

            tool_func = tools_by_name.get(tool_name)
            if tool_func is None:
                content = f"Tool {tool_name} not available"
            else:
                try:
                    result = tool_func.invoke(_parse_tool_args(call["args"]))
                    text = str(result)[:MAX_TOOL_RESULT]
                    content = f"{tool_name} result: {text}"
                except Exception as e:
                    # Report the failure to the model rather than aborting.
                    content = f"{tool_name} error: {str(e)}"

            # Reply with one ToolMessage per call, tied to the call id.
            responses.append(
                ToolMessage(
                    content=content, tool_call_id=call_id, name=tool_name
                )
            )

        return _state_update(state, responses, "tool")

    # 4. Add nodes to workflow
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)

    # 5. Set entry point
    workflow.set_entry_point("agent")

    # 6. Define conditional edges
    def should_continue(state: AgentState):
        """Route after the agent node: "tools", "agent" (retry) or "end"."""
        last_msg = state["messages"][-1]
        text = _content_to_text(last_msg.content)

        # Handle timeout, step limit and API limit errors
        terminal_markers = (
            "AGENT ERROR (GLOBAL_TIMEOUT)",
            "AGENT ERROR (STEP_LIMIT)",
            "AGENT ERROR (API_LIMIT)",
        )
        if any(marker in text for marker in terminal_markers):
            return "end"

        # Handle all other errors
        if "AGENT ERROR" in text:
            # For RESOURCE_EXHAUSTED errors, wait longer before retrying
            if "RESOURCE_EXHAUSTED" in text:
                time.sleep(10 + random.uniform(0, 10))  # Wait 10-20 seconds
            return "agent"

        # Route to tools if tool calls exist
        if _extract_tool_calls(last_msg):
            return "tools"

        # End if final answer is present
        if "FINAL ANSWER" in text:
            return "end"

        # Continue to agent otherwise
        return "agent"

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END,
        },
    )

    # 7. Define flow after tool node
    workflow.add_edge("tools", "agent")

    # 8. Compile graph
    return workflow.compile()


# Initialize agent graph
agent_graph = build_graph()


# Wrapper function to ensure execution within time limits
def run_agent(question):
    """Run the agent graph on *question*.

    Returns:
        ``{"answer": <text of the last message>}`` when the graph runs to
        completion, otherwise ``{"error": <description>}``.
    """
    # Create initial state with all required fields
    initial_state = {
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question),
        ],
        "step_count": 0,
        "start_time": time.time(),
        "last_action": "start",
        "api_errors": 0,
    }

    start_time = time.time()
    final_message = None

    try:
        for step in agent_graph.stream(initial_state):
            # Check overall timeout every step
            if time.time() - start_time > OVERALL_TIMEOUT_SECONDS:
                return {"error": "Overall execution timeout (3 minutes)"}

            # Assumes each chunk maps a key (node name, or END) to a dict
            # that may hold "messages"; other shapes are skipped. The last
            # message seen is the final one once the stream is exhausted.
            for update in step.values():
                if isinstance(update, dict) and update.get("messages"):
                    final_message = update["messages"][-1]
    except Exception as e:
        return {"error": f"Execution failed: {str(e)}"}

    # Extract final answer safely
    if final_message is None:
        return {"error": "Agent finished but produced no messages"}
    return {"answer": _content_to_text(final_message.content)}


# Example entry point (used from app.py)
def process_question(question):
    """Answer *question*, returning the answer text or an "Error: ..." string."""
    # Add initial delay to avoid burst requests
    time.sleep(1 + random.uniform(0, 2))

    response = run_agent(question)

    if "answer" in response:
        return response["answer"]
    if "error" in response:
        return f"Error: {response['error']}"
    return "Unexpected response format"