ZeroTimo commited on
Commit
257cce5
·
verified ·
1 Parent(s): c6338ed

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +53 -76
agent.py CHANGED
@@ -1,118 +1,95 @@
1
  import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from duckduckgo_search import DDGS
7
- from langchain_community.document_loaders import WikipediaLoader
8
- from langchain_community.document_loaders import ArxivLoader
9
- from langchain_core.messages import SystemMessage, HumanMessage
10
- from langchain_core.tools import tool
11
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
 
 
12
 
13
- load_dotenv()
14
-
15
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
16
 
17
- # --- Tools ---
18
  @tool
19
  def multiply(a: int, b: int) -> int:
20
- """Multiplies two integers and returns the result."""
21
  return a * b
22
 
23
  @tool
24
  def add(a: int, b: int) -> int:
25
- """Adds two integers and returns the result."""
26
  return a + b
27
 
28
  @tool
29
  def subtract(a: int, b: int) -> int:
30
- """Subtracts the second integer from the first."""
31
  return a - b
32
 
33
  @tool
34
  def divide(a: int, b: int) -> float:
35
- """Divides the first integer by the second, returns float."""
36
  if b == 0:
37
  raise ValueError("Cannot divide by zero.")
38
  return a / b
39
 
40
  @tool
41
  def modulo(a: int, b: int) -> int:
42
- """Returns the remainder of the division of two integers."""
43
  return a % b
44
 
45
  @tool
46
  def wiki_search(query: str) -> str:
47
- """Search Wikipedia for a given query and return up to 2 results formatted."""
48
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
49
- formatted = "\n\n---\n\n".join(
50
- [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>' for doc in search_docs]
51
- )
52
- return {"wiki_results": formatted}
53
 
54
  @tool
55
  def arxiv_search(query: str) -> str:
56
- """Search Arxiv for a given query and return up to 3 results formatted."""
57
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
58
- formatted = "\n\n---\n\n".join(
59
- [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>' for doc in search_docs]
60
- )
61
- return {"arxiv_results": formatted}
62
 
63
  @tool
64
  def web_search(query: str) -> str:
65
- """Search DuckDuckGo (for websearch) for a query and return up to 5 links."""
66
- with DDGS() as ddgs:
67
- results = ddgs.text(query, max_results=5)
68
- if not results:
69
- return "No results found."
70
- return "\n\n".join(f"{r['title']}: {r['href']}" for r in results)
71
-
72
- # --- Setup LLM und Tools ---
73
- tools = [
74
- multiply,
75
- add,
76
- subtract,
77
- divide,
78
- modulo,
79
- wiki_search,
80
- arxiv_search,
81
- web_search,
82
- ]
83
-
84
- system_prompt = (
85
- "You are a highly accurate AI assistant. "
86
- "Use tools when needed. Be very concise and precise. "
87
- "Do not hallucinate information."
88
  )
89
- sys_msg = SystemMessage(content=system_prompt)
90
-
91
- def build_graph():
92
- llm = ChatGoogleGenerativeAI(
93
- model="gemini-2.0-flash",
94
- google_api_key=GOOGLE_API_KEY,
95
- temperature=0,
96
- max_output_tokens=2048,
97
- system_message=sys_msg,
98
- )
99
- llm_with_tools = llm.bind_tools(tools)
100
 
101
- def assistant(state: MessagesState):
102
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
103
 
104
- builder = StateGraph(MessagesState)
105
- builder.add_node("assistant", assistant)
106
- builder.add_node("tools", ToolNode(tools))
107
- builder.add_edge(START, "assistant")
108
- builder.add_conditional_edges("assistant", tools_condition)
109
- builder.add_edge("tools", "assistant")
110
 
111
- return builder.compile()
 
 
 
 
 
 
112
 
113
- # Agent Executor für app.py
114
- def agent_executor(question: str) -> str:
115
- graph = build_graph()
116
- messages = [HumanMessage(content=question)]
117
- result = graph.invoke({"messages": messages})
118
- return result["messages"][-1].content
 
1
  import os
2
+ from langgraph.graph import StateGraph, START, MessagesState
3
+ from langgraph.prebuilt import tools_condition, ToolNode
 
 
 
 
 
 
 
4
  from langchain_google_genai import ChatGoogleGenerativeAI
5
+ from langchain_core.tools import tool
6
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
7
+ from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
 
10
+ # Lade Umgebungsvariablen (Google API Key)
 
11
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
12
 
13
+ # === Tools definieren ===
14
  @tool
15
  def multiply(a: int, b: int) -> int:
16
+ """Multiplies two numbers."""
17
  return a * b
18
 
19
  @tool
20
  def add(a: int, b: int) -> int:
21
+ """Adds two numbers."""
22
  return a + b
23
 
24
  @tool
25
  def subtract(a: int, b: int) -> int:
26
+ """Subtracts two numbers."""
27
  return a - b
28
 
29
  @tool
30
  def divide(a: int, b: int) -> float:
31
+ """Divides two numbers."""
32
  if b == 0:
33
  raise ValueError("Cannot divide by zero.")
34
  return a / b
35
 
36
  @tool
37
  def modulo(a: int, b: int) -> int:
38
+ """Returns the remainder of dividing two numbers."""
39
  return a % b
40
 
41
  @tool
42
  def wiki_search(query: str) -> str:
43
+ """Search Wikipedia for a query and return the result."""
44
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
45
+ return "\n\n".join(doc.page_content for doc in search_docs)
 
 
 
46
 
47
  @tool
48
  def arxiv_search(query: str) -> str:
49
+ """Search Arxiv for academic papers about a query."""
50
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
51
+ return "\n\n".join(doc.page_content[:1000] for doc in search_docs)
 
 
 
52
 
53
  @tool
54
  def web_search(query: str) -> str:
55
+ """Perform a DuckDuckGo web search."""
56
+ wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
57
+ results = wrapper.run(query)
58
+ return results
59
+
60
+ # === System Prompt definieren ===
61
+ system_prompt = SystemMessage(content=(
62
+ "You are an expert assistant. You MUST answer precisely, factually, and accurately. "
63
+ "If you do not know the answer, use the available tools such as Wikipedia Search, Arxiv Search, "
64
+ "or Web Search to find the correct information. "
65
+ "If a math operation is needed, use the calculation tools. "
66
+ "Do NOT invent answers. Only return answers you are confident in."
67
+ ))
68
+
69
+ # === LLM definieren ===
70
+ llm = ChatGoogleGenerativeAI(
71
+ model="gemini-2.0-flash",
72
+ google_api_key=GOOGLE_API_KEY,
73
+ temperature=0,
74
+ max_output_tokens=2048,
75
+ system_message=system_prompt,
 
 
76
  )
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ # === Tools in LLM einbinden ===
79
+ tools = [multiply, add, subtract, divide, modulo, wiki_search, arxiv_search, web_search]
80
+ llm_with_tools = llm.bind_tools(tools)
81
 
82
+ # === Nodes für LangGraph ===
83
+ def assistant(state: MessagesState):
84
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
85
 
86
+ # === LangGraph bauen ===
87
+ builder = StateGraph(MessagesState)
88
+ builder.add_node("assistant", assistant)
89
+ builder.add_node("tools", ToolNode(tools))
90
+ builder.add_edge(START, "assistant")
91
+ builder.add_conditional_edges("assistant", tools_condition)
92
+ builder.add_edge("tools", "assistant")
93
 
94
+ # === Agent Executor ===
95
+ agent_executor = builder.compile()