vtony commited on
Commit
c620442
·
verified ·
1 Parent(s): 4e486b5

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +109 -0
agent.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TypedDict, Annotated, Sequence
3
+ import operator
4
+ from langgraph.graph import StateGraph, END
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
10
+ from langchain.agents import Tool
11
+ from langchain_core.tools import tool
12
+
13
+ # Load environment variables (compatible with Hugging Face Space)
14
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
15
+ if not google_api_key:
16
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
17
+
18
+ # --- System Prompt ---
19
+ with open("System_Prompt.txt", "r", encoding="utf-8") as f:
20
+ system_prompt = f.read()
21
+ sys_msg = SystemMessage(content=system_prompt)
22
+
23
+ # --- Tool Definitions ---
24
+ @tool
25
+ def multiply(a: int, b: int) -> int:
26
+ """Multiply two integers together."""
27
+ return a * b
28
+
29
+ @tool
30
+ def add(a: int, b: int) -> int:
31
+ """Add two integers together."""
32
+ return a + b
33
+
34
+ @tool
35
+ def subtract(a: int, b: int) -> int:
36
+ """Subtract b from a."""
37
+ return a - b
38
+
39
+ @tool
40
+ def divide(a: int, b: int) -> float:
41
+ """Divide a by b. Returns float. Raises error if b is zero."""
42
+ if b == 0:
43
+ raise ValueError("Cannot divide by zero.")
44
+ return a / b
45
+
46
+ @tool
47
+ def wiki_search(query: str) -> str:
48
+ """Search Wikipedia and return up to 2 relevant documents."""
49
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
50
+ if not docs:
51
+ return "No Wikipedia results found."
52
+ return "\n\n".join([d.page_content[:1000] for d in docs])
53
+
54
+ # Tool inventory with proper categorization
55
+ tools = [
56
+ Tool(name="Math/Multiply", func=multiply, description="Multiplies two integers"),
57
+ Tool(name="Math/Add", func=add, description="Adds two integers"),
58
+ Tool(name="Math/Subtract", func=subtract, description="Subtracts two integers"),
59
+ Tool(name="Math/Divide", func=divide, description="Divides two numbers"),
60
+ Tool(name="Search/Wikipedia", func=wiki_search, description="Searches Wikipedia"),
61
+ Tool(
62
+ name="Search/Web",
63
+ func=DuckDuckGoSearchRun().run,
64
+ description="Searches the web using DuckDuckGo"
65
+ )
66
+ ]
67
+
68
+ # --- Graph Definition ---
69
+ class AgentState(TypedDict):
70
+ """State definition for the agent workflow"""
71
+ messages: Annotated[Sequence[BaseMessage], operator.add]
72
+
73
+ def build_graph():
74
+ """Constructs and compiles the LangGraph workflow"""
75
+
76
+ # Initialize LLM with Gemini 2.0 Flash
77
+ llm = ChatGoogleGenerativeAI(
78
+ model="gemini-2.0-flash-exp",
79
+ temperature=0.3,
80
+ google_api_key=google_api_key
81
+ )
82
+ llm_with_tools = llm.bind_tools(tools)
83
+
84
+ # Node definitions
85
+ def agent_node(state: AgentState):
86
+ """Main agent node that processes messages"""
87
+ response = llm_with_tools.invoke(state["messages"])
88
+ return {"messages": [response]}
89
+
90
+ # Graph construction
91
+ workflow = StateGraph(AgentState)
92
+
93
+ # Add nodes to the workflow
94
+ workflow.add_node("agent", agent_node)
95
+ workflow.add_node("tools", ToolNode(tools))
96
+
97
+ # Configure graph flow
98
+ workflow.set_entry_point("agent")
99
+ workflow.add_conditional_edges(
100
+ "agent",
101
+ tools_condition, # LangGraph's built-in tool detection
102
+ {"tools": "tools", "end": END} # Fixed END reference
103
+ )
104
+ workflow.add_edge("tools", "agent")
105
+
106
+ return workflow.compile()
107
+
108
+ # Initialize the agent graph
109
+ agent_graph = build_graph()