import json
import operator
import os
import time
from typing import Annotated, Sequence, TypedDict

from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph
from tenacity import retry, stop_after_attempt, wait_exponential

# Load environment variables (e.g. from a local .env file).
load_dotenv()
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise ValueError("Missing GOOGLE_API_KEY environment variable")


# --- Math Tools ---
# NOTE: with @tool, each docstring becomes the tool description sent to the
# model, so the docstrings below are runtime text -- edit them deliberately.


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    return a % b


# --- Browser Tools ---
# Each search tool reports failures as a string instead of raising, so the
# model sees the error text and the graph run continues.


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    try:
        docs = WikipediaLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No Wikipedia results found."
        results = []
        for doc in docs:
            title = doc.metadata.get('title', 'Unknown Title')
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nContent: {content}")
        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"Wikipedia search error: {str(e)}"


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."
        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            authors = doc.metadata.get('Authors', "")
            # ArxivLoader stores Authors as an already-joined string; joining
            # a string would interleave ", " between its characters. Only join
            # when we actually received a sequence of names.
            if not isinstance(authors, str):
                authors = ", ".join(str(a) for a in authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"


@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        search = DuckDuckGoSearchRun()
        result = search.run(query)
        return f"Web search results for '{query}':\n{result[:2000]}"  # Limit content length
    except Exception as e:
        return f"Web search error: {str(e)}"


# --- Load system prompt ---
# NOTE: the path is resolved relative to the current working directory.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# --- Tool Setup ---
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    arxiv_search,
    web_search,
]


def _message_text(message) -> str:
    """Return a message's content flattened to a single string.

    LangChain message ``content`` may be a plain string or a list of parts
    (strings, or dicts carrying a ``"text"`` key); substring checks such as
    ``"AGENT ERROR" in ...`` need the flattened text.
    """
    content = getattr(message, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict):
                parts.append(str(part.get("text", "")))
        return "".join(parts)
    return str(content)


# --- Graph Builder ---
def build_graph():
    """Build and compile the agent/tool LangGraph loop.

    The graph starts at the ``agent`` node (Gemini with ``tools`` bound). If
    the model requests tool calls, the ``tools`` node executes them and the
    flow returns to ``agent``; otherwise (or on an agent error) the run ends.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    # Initialize the Gemini 1.5 Flash chat model.
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,
    )

    # Bind tools to the LLM so it can emit structured tool calls.
    llm_with_tools = llm.bind_tools(tools)
    # Name -> tool lookup used when executing tool calls.
    tools_by_name = {t.name: t for t in tools}

    # 1. State schema: nodes return new messages, which operator.add appends
    #    to the accumulated history.
    class AgentState(TypedDict):
        messages: Annotated[Sequence, operator.add]

    # 2. Create the graph.
    workflow = StateGraph(AgentState)

    # Model call with retries. reraise=True makes tenacity re-raise the
    # original exception after the last attempt (instead of RetryError), so
    # the "429"/"400" classification below can see the provider's message.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True,
    )
    def invoke_with_retry(messages):
        return llm_with_tools.invoke(messages)

    # 3. Node functions.
    def agent_node(state: AgentState):
        """Main agent node: call the model and append its reply to the state."""
        messages = list(state["messages"])
        # Send the system prompt (loaded from system_prompt.txt) unless the
        # caller already supplied a leading SystemMessage. It is only added to
        # the request, not stored in the state, so it does not accumulate.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages.insert(0, SystemMessage(content=system_prompt))
        try:
            # Pace consecutive requests to the model.
            time.sleep(1)
            response = invoke_with_retry(messages)
            return {"messages": [response]}
        except Exception as e:
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"
            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            return {"messages": [AIMessage(content=error_msg)]}

    def tool_node(state: AgentState):
        """Execute every tool call requested by the last AI message.

        Returns one ToolMessage per call, each linked to its call via
        ``tool_call_id`` so the model receives a response for every request.
        """
        last_msg = state["messages"][-1]
        # LangChain normalizes provider tool calls into AIMessage.tool_calls
        # as dicts with "name", "args" and "id" keys.
        tool_calls = getattr(last_msg, "tool_calls", None) or []

        responses = []
        for call in tool_calls:
            tool_name = call.get("name", "")
            tool_args = call.get("args") or {}
            call_id = call.get("id") or tool_name

            tool_func = tools_by_name.get(tool_name)
            if tool_func is None:
                content = f"Tool {tool_name} not available"
            else:
                try:
                    # Defensive: accept JSON-encoded argument strings too.
                    if isinstance(tool_args, str):
                        tool_args = json.loads(tool_args)
                    result = tool_func.invoke(tool_args)
                    # str() first: math tools return int/float, which cannot
                    # be sliced. Truncate to limit result length.
                    content = f"{tool_name} result: {str(result)[:1000]}"
                except Exception as e:
                    content = f"{tool_name} error: {str(e)}"

            responses.append(
                ToolMessage(content=content, tool_call_id=call_id, name=tool_name)
            )

        return {"messages": responses}

    # 4. Register nodes.
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)

    # 5. Entry point.
    workflow.set_entry_point("agent")

    # 6. Conditional routing after the agent node.
    def should_continue(state: AgentState):
        """Route to "tools" when the model requested tools, else to "end"."""
        last_msg = state["messages"][-1]

        # Agent errors end the run immediately.
        if "AGENT ERROR" in _message_text(last_msg):
            return "end"

        # Tool calls requested: execute them, then return to the agent.
        if getattr(last_msg, "tool_calls", None):
            return "tools"

        # No tool call means the model has answered (ideally with a
        # "FINAL ANSWER"). Re-invoking it with an unchanged conversation
        # could repeat the same reply until the recursion limit is hit.
        return "end"

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "tools": "tools",
            "end": END,
        },
    )

    # 7. After tools run, hand control back to the agent.
    workflow.add_edge("tools", "agent")

    # 8. Compile the graph.
    return workflow.compile()


# Build the agent graph once at import time.
agent_graph = build_graph()