"""GAIA agent: a LangGraph tool-calling loop around Gemini.

The module exposes ``agent_graph``, a compiled graph with two nodes:

* ``agent`` - calls the Gemini chat model (with the tools bound) on the
  conversation accumulated in the state.
* ``tools`` - executes every tool call requested by the last model message
  and appends one ``ToolMessage`` per call.

Importing the module has side effects: it loads ``.env``, requires
``GOOGLE_API_KEY``, reads ``system_prompt.txt`` from the current working
directory and builds the graph.
"""
import json
import logging
import operator
import os
import time
from typing import Annotated, Any, Sequence, TypedDict

from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from tenacity import retry, stop_after_attempt, wait_exponential

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GAIA_Agent")

# Load environment variables
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")

# --- Tunables ---
MAX_DOC_CHARS = 2000          # per-document cap applied by the search tools
MAX_TOOL_RESULT_CHARS = 1000  # cap on a tool result handed back to the model
MAX_AGENT_RETRIES = 3         # graph-level retries after a failed model call
REQUEST_DELAY_SECONDS = 2     # pause before each model call (rate limiting)
ERROR_PREFIX = "AGENT ERROR"  # marks synthetic error messages from agent_node

# Substrings looked for in the exception text, checked in this order.
_ERROR_TYPES = (
    ("429", "QUOTA_EXCEEDED"),
    ("400", "INVALID_REQUEST"),
    ("503", "SERVICE_UNAVAILABLE"),
)


# --- Math Tools ---
# NOTE: with @tool the docstring is the description sent to the model, so the
# tool docstrings below are part of runtime behavior and are kept terse.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    return a % b


# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    # The query is passed through unchanged; the model decides the wording.
    try:
        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."
        results = []
        for doc in docs:
            title = doc.metadata.get("title", "Unknown Title")
            content = doc.page_content[:MAX_DOC_CHARS]  # Limit content length
            results.append(f"Title: {title}\nContent: {content}")
        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"Wikipedia search error: {str(e)}"


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."
        results = []
        for doc in docs:
            title = doc.metadata.get("Title", "Unknown Title")
            authors = ", ".join(doc.metadata.get("Authors", []))
            content = doc.page_content[:MAX_DOC_CHARS]  # Limit content length
            results.append(
                f"Title: {title}\nAuthors: {authors}\nContent: {content}"
            )
        return "\n\n---\n\n".join(results)
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"arXiv search error: {str(e)}"


@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        search = DuckDuckGoSearchRun()
        result = search.run(query)
        # Limit content length
        return f"Web search results for '{query}':\n{result[:MAX_DOC_CHARS]}"
    except Exception as e:
        # Best effort: report the failure to the model instead of raising.
        return f"Web search error: {str(e)}"


# --- Load system prompt ---
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# --- Tool Setup ---
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]


# --- State and helpers ---
class AgentState(TypedDict):
    """Graph state shared by the nodes."""

    # Conversation so far; node outputs are concatenated via operator.add.
    messages: Annotated[Sequence, operator.add]
    # Consecutive failed model calls; reset to 0 by any successful node run.
    retry_count: int


def _message_text(message: Any) -> str:
    """Return the textual content of a message as a single string.

    Message content may be a plain string or a list of parts (strings or
    dicts carrying a ``"text"`` entry); non-text parts are ignored.
    """
    content = getattr(message, "content", "")
    if isinstance(content, str):
        return content
    pieces = []
    for part in content or []:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            pieces.append(part["text"])
    return "".join(pieces)


def _is_agent_error(message: Any) -> bool:
    """Tell whether *message* is a synthetic error emitted by the agent node."""
    return isinstance(message, AIMessage) and _message_text(message).startswith(
        ERROR_PREFIX
    )


def _classify_error(exc: BaseException) -> str:
    """Map an exception to a coarse error label based on its text."""
    text = str(exc)
    return next(
        (label for code, label in _ERROR_TYPES if code in text), "UNKNOWN"
    )


def _coerce_tool_args(raw_args: Any) -> Any:
    """Normalize tool-call arguments into something ``tool.invoke`` accepts.

    Arguments normally arrive as a dict. A string is parsed as JSON; if that
    fails it is treated as the ``query`` of a search tool.
    """
    if isinstance(raw_args, str):
        try:
            return json.loads(raw_args)
        except json.JSONDecodeError:
            return {"query": raw_args}
    return raw_args or {}


# --- Graph Builder ---
def build_graph():
    """Build and compile the agent/tools LangGraph workflow.

    Returns:
        The compiled graph. Its state is an ``AgentState``; callers provide
        at least ``{"messages": [...]}``.
    """
    # Initialize model with Gemini 2.5 Flash
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=5,
        request_timeout=60,
    )

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)
    tools_by_name = {t.name: t for t in tools}

    # reraise=True surfaces the original exception (not tenacity.RetryError),
    # so the status code is still visible to _classify_error.
    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=30),
        reraise=True,
    )
    def invoke_with_retry(messages):
        return llm_with_tools.invoke(messages)

    workflow = StateGraph(AgentState)

    def agent_node(state: AgentState) -> dict:
        """Call the model on the conversation; on failure emit an error message."""
        # Earlier error placeholders are bookkeeping for the router, not part
        # of the dialogue, so they are not replayed to the model.
        history = [m for m in state["messages"] if not _is_agent_error(m)]
        # Supply the module's system prompt when the caller did not send one.
        if system_prompt.strip() and not any(
            isinstance(m, SystemMessage) for m in history
        ):
            history = [SystemMessage(content=system_prompt), *history]

        try:
            # Add request delay to avoid rate limiting
            time.sleep(REQUEST_DELAY_SECONDS)
            response = invoke_with_retry(history)
        except Exception as e:
            # Node boundary: turn any failure into a message the router can
            # act on instead of crashing the graph run.
            error_details = f"Gemini API Error: {type(e).__name__}: {str(e)}"
            logger.error(error_details)
            error_type = _classify_error(e)
            new_retry_count = state.get("retry_count", 0) + 1
            error_msg = f"AGENT ERROR ({error_type}): {error_details[:300]}"
            if new_retry_count < MAX_AGENT_RETRIES:
                error_msg += "\n\nWill retry after delay..."
            else:
                error_msg += "\n\nMax retries exceeded. Please try again later."
            return {
                "messages": [AIMessage(content=error_msg)],
                "retry_count": new_retry_count,
            }
        return {"messages": [response], "retry_count": 0}

    def tool_node(state: AgentState) -> dict:
        """Run every tool call of the last message; one ToolMessage per call."""
        last_msg = state["messages"][-1]
        tool_messages = []
        for call in getattr(last_msg, "tool_calls", None) or []:
            tool_name = call["name"]
            selected = tools_by_name.get(tool_name)
            if selected is None:
                content = f"Tool {tool_name} not available"
            else:
                try:
                    result = selected.invoke(_coerce_tool_args(call.get("args")))
                    content = (
                        f"{tool_name} result: "
                        f"{str(result)[:MAX_TOOL_RESULT_CHARS]}"
                    )
                except Exception as e:
                    # Report the failure to the model rather than aborting.
                    content = f"{tool_name} error: {str(e)}"
            tool_messages.append(
                ToolMessage(
                    content=content,
                    # Every call gets an answer tied to its id; fall back to
                    # the tool name if the id is missing.
                    tool_call_id=call.get("id") or tool_name,
                    name=tool_name,
                )
            )
        return {"messages": tool_messages, "retry_count": 0}

    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")

    def should_continue(state: AgentState) -> str:
        """Pick the next step after the agent node: agent, tools or end."""
        last_msg = state["messages"][-1]

        # Handle error cases
        if _is_agent_error(last_msg):
            if state.get("retry_count", 0) < MAX_AGENT_RETRIES:
                return "agent"
            return "end"

        # Route to tools if tool calls exist
        if getattr(last_msg, "tool_calls", None):
            return "tools"

        # End if final answer is present
        if "FINAL ANSWER" in _message_text(last_msg):
            return "end"

        # NOTE(review): a reply with neither tool calls nor "FINAL ANSWER"
        # re-invokes the model on the same history; nothing in this module
        # bounds that loop - confirm the graph's recursion limit is the
        # intended stop condition.
        return "agent"

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END,
        },
    )

    # After tools have run, hand the results back to the model.
    workflow.add_edge("tools", "agent")

    return workflow.compile()


# Initialize agent graph
agent_graph = build_graph()