ZeroTimo commited on
Commit
60684f0
·
verified ·
1 Parent(s): e7d1b6d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +46 -31
agent.py CHANGED
@@ -4,32 +4,35 @@ from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_community.tools.duckduckgo_search import DuckDuckGoSearchResults
7
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
8
  from langchain_core.messages import SystemMessage, HumanMessage
9
  from langchain_core.tools import tool
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
 
 
12
  load_dotenv()
13
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
14
 
15
- # --- Define Tools ---
 
16
 
17
- @tool
 
18
  def multiply(a: int, b: int) -> int:
19
  """Multiplies two numbers."""
20
  return a * b
21
 
22
- @tool
23
  def add(a: int, b: int) -> int:
24
  """Adds two numbers."""
25
  return a + b
26
 
27
- @tool
28
  def subtract(a: int, b: int) -> int:
29
  """Subtracts two numbers."""
30
  return a - b
31
 
32
- @tool
33
  def divide(a: int, b: int) -> float:
34
  """Divides two numbers."""
35
  if b == 0:
@@ -38,43 +41,40 @@ def divide(a: int, b: int) -> float:
38
 
39
  @tool
40
  def modulo(a: int, b: int) -> int:
41
- """Returns the remainder of dividing two numbers."""
42
  return a % b
43
 
44
  @tool
45
  def wiki_search(query: str) -> str:
46
  """Search Wikipedia for a query and return up to 2 results."""
47
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
48
- formatted = "\n\n---\n\n".join(
49
- [f'<Document source="{doc.metadata["source"]}">\n{doc.page_content}\n</Document>'
50
- for doc in search_docs]
 
 
51
  )
52
- return formatted
53
 
54
  @tool
55
  def arxiv_search(query: str) -> str:
56
- """Search Arxiv for scientific papers matching the query."""
57
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
58
- formatted = "\n\n---\n\n".join(
59
- [f'<Document source="{doc.metadata["source"]}">\n{doc.page_content[:1000]}\n</Document>'
60
- for doc in search_docs]
 
 
61
  )
62
- return formatted
63
 
64
  @tool
65
  def web_search(query: str) -> str:
66
- """Search the web using DuckDuckGo."""
67
- search = DuckDuckGoSearchResults()
68
  return search.run(query)
69
 
70
- # --- Load System Prompt ---
71
-
72
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
73
- system_prompt = f.read()
74
-
75
- sys_msg = SystemMessage(content=system_prompt)
76
-
77
- # --- Define Tools List ---
78
  tools = [
79
  multiply,
80
  add,
@@ -86,26 +86,41 @@ tools = [
86
  web_search,
87
  ]
88
 
89
- # --- Build Graph Function ---
 
 
 
 
 
 
90
 
 
91
  def build_graph():
92
  llm = ChatGoogleGenerativeAI(
93
  model="gemini-2.0-flash",
94
  google_api_key=GOOGLE_API_KEY,
95
  temperature=0,
96
  max_output_tokens=2048,
97
- system_message=sys_msg
98
  )
99
  llm_with_tools = llm.bind_tools(tools)
100
 
101
  def assistant(state: MessagesState):
102
  """Assistant Node"""
103
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
104
-
105
  builder = StateGraph(MessagesState)
106
  builder.add_node("assistant", assistant)
107
  builder.add_node("tools", ToolNode(tools))
108
  builder.add_edge(START, "assistant")
109
  builder.add_conditional_edges("assistant", tools_condition)
110
  builder.add_edge("tools", "assistant")
111
- return builder.compile()
 
 
 
 
 
 
 
 
 
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_community.tools.duckduckgo_search import DuckDuckGoSearchResults
7
+ from langchain_community.document_loaders import WikipediaLoader
8
+ from langchain_community.document_loaders import ArxivLoader
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
11
  from langchain_google_genai import ChatGoogleGenerativeAI
12
 
13
+ # .env laden (falls lokal)
14
  load_dotenv()
 
15
 
16
+ # Google API Key aus Environment
17
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
18
 
19
+ # --- Tools definieren ---
20
+ @tool
21
  def multiply(a: int, b: int) -> int:
22
  """Multiplies two numbers."""
23
  return a * b
24
 
25
+ @tool
26
  def add(a: int, b: int) -> int:
27
  """Adds two numbers."""
28
  return a + b
29
 
30
+ @tool
31
  def subtract(a: int, b: int) -> int:
32
  """Subtracts two numbers."""
33
  return a - b
34
 
35
+ @tool
36
  def divide(a: int, b: int) -> float:
37
  """Divides two numbers."""
38
  if b == 0:
 
41
 
42
  @tool
43
  def modulo(a: int, b: int) -> int:
44
+ """Returns the remainder after division."""
45
  return a % b
46
 
47
  @tool
48
  def wiki_search(query: str) -> str:
49
  """Search Wikipedia for a query and return up to 2 results."""
50
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
51
+ formatted_search_docs = "\n\n---\n\n".join(
52
+ [
53
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
54
+ for doc in search_docs
55
+ ]
56
  )
57
+ return {"wiki_results": formatted_search_docs}
58
 
59
  @tool
60
  def arxiv_search(query: str) -> str:
61
+ """Search Arxiv for a query and return up to 3 results."""
62
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
63
+ formatted_search_docs = "\n\n---\n\n".join(
64
+ [
65
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
66
+ for doc in search_docs
67
+ ]
68
  )
69
+ return {"arxiv_results": formatted_search_docs}
70
 
71
  @tool
72
  def web_search(query: str) -> str:
73
+ """Search DuckDuckGo for a query and return results."""
74
+ search = DuckDuckGoSearchResults(max_results=5)
75
  return search.run(query)
76
 
77
+ # Tools-Liste
 
 
 
 
 
 
 
78
  tools = [
79
  multiply,
80
  add,
 
86
  web_search,
87
  ]
88
 
89
+ # System Prompt
90
+ system_prompt = (
91
+ "You are a highly accurate AI assistant. "
92
+ "Use tools when needed. Be very concise and precise. "
93
+ "Do not hallucinate information."
94
+ )
95
+ sys_msg = SystemMessage(content=system_prompt)
96
 
97
+ # --- Build Graph ---
98
  def build_graph():
99
  llm = ChatGoogleGenerativeAI(
100
  model="gemini-2.0-flash",
101
  google_api_key=GOOGLE_API_KEY,
102
  temperature=0,
103
  max_output_tokens=2048,
104
+ system_message=sys_msg,
105
  )
106
  llm_with_tools = llm.bind_tools(tools)
107
 
108
  def assistant(state: MessagesState):
109
  """Assistant Node"""
110
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
111
+
112
  builder = StateGraph(MessagesState)
113
  builder.add_node("assistant", assistant)
114
  builder.add_node("tools", ToolNode(tools))
115
  builder.add_edge(START, "assistant")
116
  builder.add_conditional_edges("assistant", tools_condition)
117
  builder.add_edge("tools", "assistant")
118
+
119
+ return builder.compile()
120
+
121
+ # --- Agent Executor für app.py ---
122
+ def agent_executor(question: str) -> str:
123
+ graph = build_graph()
124
+ messages = [HumanMessage(content=question)]
125
+ result = graph.invoke({"messages": messages})
126
+ return result["messages"][-1].content