ZeroTimo commited on
Commit
d046ba6
·
verified ·
1 Parent(s): f5078a2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +102 -84
agent.py CHANGED
@@ -1,93 +1,111 @@
1
  import os
 
 
 
 
 
 
 
 
2
  from langchain_google_genai import ChatGoogleGenerativeAI
3
- from langchain.agents import initialize_agent, Tool, AgentType
4
- from langchain_community.tools import DuckDuckGoSearchResults, WikipediaQueryRun
5
- from langchain_experimental.tools import PythonREPLTool
6
- from langchain.tools import tool
7
- from langchain.memory import ConversationBufferMemory
8
- from langchain_core.messages import SystemMessage
9
- import pandas as pd
10
-
11
- # API Key automatisch aus Environment ziehen
12
- google_api_key = os.getenv("GOOGLE_API_KEY")
13
-
14
- # LLM: Gemini 2.0 Flash
15
- llm = ChatGoogleGenerativeAI(
16
- model="gemini-2.0-flash",
17
- google_api_key=google_api_key,
18
- temperature=0,
19
- max_output_tokens=2048,
20
- system_message=SystemMessage(content=(
21
- "You are a highly accurate AI assistant. "
22
- "You must answer precisely, concisely, and only if you are confident. "
23
- "Use the available tools like Web Search, Wikipedia, Python REPL, or Table Analysis if needed. "
24
- "Always prefer exact information over assumptions."
25
- ))
26
- )
27
-
28
- # Tool 1: Web Search
29
- web_search = DuckDuckGoSearchResults()
30
-
31
- # Tool 2: Wikipedia Search
32
- wiki_search = WikipediaQueryRun()
33
-
34
- # Tool 3: Python REPL
35
- python_repl = PythonREPLTool()
36
-
37
- # Tool 4: Analyze CSV files (sehr einfaches Tool)
38
  @tool
39
- def analyze_csv(content: str) -> str:
40
- """Analyzes CSV data and provides basic statistics and insights."""
41
- try:
42
- from io import StringIO
43
- df = pd.read_csv(StringIO(content))
44
- return str(df.describe())
45
- except Exception as e:
46
- return f"Failed to analyze CSV: {str(e)}"
47
-
48
- # Tool 5: Analyze Excel files
49
  @tool
50
- def analyze_excel(content: bytes) -> str:
51
- """Analyzes Excel data and provides basic statistics and insights."""
52
- try:
53
- from io import BytesIO
54
- df = pd.read_excel(BytesIO(content))
55
- return str(df.describe())
56
- except Exception as e:
57
- return f"Failed to analyze Excel: {str(e)}"
58
-
59
- # Alle Tools zusammen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  tools = [
61
- Tool(
62
- name="WebSearch",
63
- func=web_search.run,
64
- description="Use this to search the internet for up-to-date or unknown information."
65
- ),
66
- Tool(
67
- name="WikipediaSearch",
68
- func=wiki_search.run,
69
- description="Use this to search Wikipedia articles when a direct lookup of factual information is needed."
70
- ),
71
- Tool(
72
- name="Python_REPL",
73
- func=python_repl.run,
74
- description="Use this for math problems, small code executions, or calculations."
75
- ),
76
- analyze_csv,
77
- analyze_excel,
78
  ]
79
 
80
- # Memory (optional, für Chat-History)
81
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
82
-
83
- # Agent
84
- agent_executor = initialize_agent(
85
- tools=tools,
86
- llm=llm,
87
- agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
88
- verbose=True,
89
- memory=memory,
90
- handle_parsing_errors=True,
91
- )
92
 
 
 
 
 
 
 
 
 
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_community.tools.duckduckgo_search import DuckDuckGoSearchResults
7
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
+ from langchain_core.tools import tool
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
+
12
+ load_dotenv()
13
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
14
+
15
+ # --- Define Tools ---
16
+
17
+ @tool
18
+ def multiply(a: int, b: int) -> int:
19
+ """Multiplies two numbers."""
20
+ return a * b
21
+
22
+ @tool
23
+ def add(a: int, b: int) -> int:
24
+ """Adds two numbers."""
25
+ return a + b
26
+
27
+ @tool
28
+ def subtract(a: int, b: int) -> int:
29
+ """Subtracts two numbers."""
30
+ return a - b
31
+
32
+ @tool
33
+ def divide(a: int, b: int) -> float:
34
+ """Divides two numbers."""
35
+ if b == 0:
36
+ raise ValueError("Cannot divide by zero.")
37
+ return a / b
38
+
 
 
 
 
 
 
 
39
  @tool
40
+ def modulo(a: int, b: int) -> int:
41
+ """Returns the remainder of dividing two numbers."""
42
+ return a % b
43
+
 
 
 
 
 
 
44
  @tool
45
+ def wiki_search(query: str) -> str:
46
+ """Search Wikipedia for a query and return up to 2 results."""
47
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
48
+ formatted = "\n\n---\n\n".join(
49
+ [f'<Document source="{doc.metadata["source"]}">\n{doc.page_content}\n</Document>'
50
+ for doc in search_docs]
51
+ )
52
+ return formatted
53
+
54
+ @tool
55
+ def arxiv_search(query: str) -> str:
56
+ """Search Arxiv for scientific papers matching the query."""
57
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
58
+ formatted = "\n\n---\n\n".join(
59
+ [f'<Document source="{doc.metadata["source"]}">\n{doc.page_content[:1000]}\n</Document>'
60
+ for doc in search_docs]
61
+ )
62
+ return formatted
63
+
64
+ @tool
65
+ def web_search(query: str) -> str:
66
+ """Search the web using DuckDuckGo."""
67
+ search = DuckDuckGoSearchResults()
68
+ return search.run(query)
69
+
70
+ # --- Load System Prompt ---
71
+
72
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
73
+ system_prompt = f.read()
74
+
75
+ sys_msg = SystemMessage(content=system_prompt)
76
+
77
+ # --- Define Tools List ---
78
  tools = [
79
+ multiply,
80
+ add,
81
+ subtract,
82
+ divide,
83
+ modulo,
84
+ wiki_search,
85
+ arxiv_search,
86
+ web_search,
 
 
 
 
 
 
 
 
 
87
  ]
88
 
89
+ # --- Build Graph Function ---
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ def build_graph():
92
+ llm = ChatGoogleGenerativeAI(
93
+ model="gemini-2.0-flash",
94
+ google_api_key=GOOGLE_API_KEY,
95
+ temperature=0,
96
+ max_output_tokens=2048,
97
+ system_message=sys_msg
98
+ )
99
+ llm_with_tools = llm.bind_tools(tools)
100
 
101
+ def assistant(state: MessagesState):
102
+ """Assistant Node"""
103
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
104
+
105
+ builder = StateGraph(MessagesState)
106
+ builder.add_node("assistant", assistant)
107
+ builder.add_node("tools", ToolNode(tools))
108
+ builder.add_edge(START, "assistant")
109
+ builder.add_conditional_edges("assistant", tools_condition)
110
+ builder.add_edge("tools", "assistant")
111
+ return builder.compile()