FrancescaScipioni commited on
Commit
83eac9e
·
verified ·
1 Parent(s): 67d9b11

added the final agent graph, plus a test question for testing the agent

Browse files
Files changed (1) hide show
  1. agent.py +116 -197
agent.py CHANGED
@@ -1,203 +1,88 @@
1
- from langchain.tools import Tool
2
- from langchain.utilities import WikipediaAPIWrapper, ArxivAPIWrapper, DuckDuckGoSearchRun
3
  import math
4
  import whisper
5
- from youtube_transcript_api import YouTubeTranscriptApi
6
- from PIL import Image
7
- import pytesseract
8
  import pandas as pd
 
 
9
  from dotenv import load_dotenv
 
 
10
 
11
- from langgraph.graph import StateGraph, START, END
12
- from langgraph.prebuilt import ToolNode, tools_condition
 
 
13
  from langchain_openai import ChatOpenAI
14
  from langchain_core.messages import HumanMessage, SystemMessage
15
- from typing import TypedDict, Dict, Any, Optional, List
 
 
16
 
 
17
  load_dotenv()
18
-
19
- ## ----- API KEYS ----- ##
20
-
21
  openai_api_key = os.getenv("OPENAI_API_KEY")
22
 
23
- ## ----- TOOLS DEFINITION ----- ##
24
-
25
- # ** Math Tools ** #
26
-
27
- def add_numbers(a: float, b: float) -> float:
28
- """
29
- Add two floating-point numbers.
30
-
31
- Args:
32
- a (float): The first number.
33
- b (float): The second number.
34
-
35
- Returns:
36
- float: The result of the addition.
37
- """
38
- return a + b
39
-
40
- def subtract_numbers(a: float, b: float) -> float:
41
- """
42
- Subtract the second floating-point number from the first.
43
-
44
- Args:
45
- a (float): The first number.
46
- b (float): The second number.
47
-
48
- Returns:
49
- float: The result of the subtraction.
50
- """
51
- return a - b
52
-
53
- def multiply_numbers(a: float, b: float) -> float:
54
- """
55
- Multiply two floating-point numbers.
56
-
57
- Args:
58
- a (float): The first number.
59
- b (float): The second number.
60
-
61
- Returns:
62
- float: The result of the multiplication.
63
- """
64
- return a * b
65
 
 
 
 
 
66
  def divide_numbers(a: float, b: float) -> float:
67
- """
68
- Divide the first floating-point number by the second.
69
-
70
- Args:
71
- a (float): The numerator.
72
- b (float): The denominator.
73
-
74
- Returns:
75
- float: The result of the division.
76
-
77
- Raises:
78
- ValueError: If division by zero is attempted.
79
- """
80
- if b == 0:
81
- raise ValueError("Division by zero")
82
  return a / b
83
-
84
- def power(a: float, b: float) -> float:
85
- """
86
- Raise the first number to the power of the second.
87
-
88
- Args:
89
- a (float): The base.
90
- b (float): The exponent.
91
-
92
- Returns:
93
- float: The result of the exponentiation.
94
- """
95
- return a ** b
96
-
97
- def modulus(a: float, b: float) -> float:
98
- """
99
- Compute the modulus (remainder) of the division of a by b.
100
-
101
- Args:
102
- a (float): The dividend.
103
- b (float): The divisor.
104
-
105
- Returns:
106
- float: The remainder after division.
107
- """
108
- return a % b
109
-
110
  def square_root(a: float) -> float:
111
- """
112
- Compute the square root of a number.
113
-
114
- Args:
115
- a (float): The number.
116
-
117
- Returns:
118
- float: The square root.
119
-
120
- Raises:
121
- ValueError: If a is negative.
122
- """
123
- if a < 0:
124
- raise ValueError("Cannot compute square root of a negative number")
125
  return math.sqrt(a)
126
-
127
  def logarithm(a: float, base: float = math.e) -> float:
128
- """
129
- Compute the logarithm of a number with a specified base.
130
-
131
- Args:
132
- a (float): The number.
133
- base (float, optional): The logarithmic base (default is natural log).
134
-
135
- Returns:
136
- float: The logarithm.
137
-
138
- Raises:
139
- ValueError: If a or base is not positive.
140
- """
141
- if a <= 0 or base <= 0:
142
- raise ValueError("Logarithm arguments must be positive")
143
  return math.log(a, base)
144
 
145
- # ** Search Tools ** #
146
-
147
- # DuckDuckGo Web Search
148
- duckduckgo_search = DuckDuckGoSearchRun()
149
  web_search_tool = Tool.from_function(
150
- func=duckduckgo_search.run,
151
  name="Web Search",
152
- description="Use this tool to search the internet for general-purpose queries."
153
  )
154
 
155
- # Wikipedia Search
156
- wikipedia_search = WikipediaAPIWrapper()
157
  wikipedia_tool = Tool.from_function(
158
- func=wikipedia_search.run,
159
  name="Wikipedia Search",
160
- description="Use this tool to search Wikipedia for factual or encyclopedic information."
161
  )
162
 
163
- # ArXiv Search
164
- arxiv_search = ArxivAPIWrapper()
165
  arxiv_tool = Tool.from_function(
166
- func=arxiv_search.run,
167
  name="ArXiv Search",
168
- description="Use this tool to search ArXiv for scientific papers. Input should be a research topic or query."
169
  )
170
 
171
- # ** Audio Transcription Tool ** #
172
-
173
- model = whisper.load_model("base")
174
-
175
- @tool
176
  def transcribe_audio(file_path: str) -> str:
177
- """Transcribe spoken words from an audio file into text."""
178
- result = model.transcribe(file_path)
179
- return result["text"]
180
 
181
- # ** youtube-transcript-api Tool ** #
182
-
183
- @tool
184
  def get_youtube_transcript(video_id: str) -> str:
185
- """Get transcript of a YouTube video from its video ID."""
186
  transcript = YouTubeTranscriptApi.get_transcript(video_id)
187
- return " ".join([entry["text"] for entry in transcript])
188
-
189
- # ** Image Tool ** #
190
 
191
- @tool
 
192
  def extract_text_from_image(image_path: str) -> str:
193
- """Extract text from an image using OCR."""
194
  return pytesseract.image_to_string(Image.open(image_path))
195
 
196
- # ** Code Execution Tool ** #
197
-
198
- @tool
199
  def execute_python_code(code: str) -> str:
200
- """Execute a Python code string and return the output."""
201
  try:
202
  local_vars = {}
203
  exec(code, {}, local_vars)
@@ -205,60 +90,94 @@ def execute_python_code(code: str) -> str:
205
  except Exception as e:
206
  return f"Error: {e}"
207
 
208
- # ** Excel Parsing Tool ** #
209
-
210
- @tool
211
  def total_sales_from_excel(file_path: str) -> str:
212
  """Compute total food sales from an Excel file."""
213
  df = pd.read_excel(file_path)
214
  food_df = df[df["Category"] == "Food"]
215
- total_sales = food_df["Sales"].sum()
216
- return f"{total_sales:.2f} USD"
217
-
218
 
219
- ## ----- TOOLS LIST ----- ##
220
 
221
  tools = [
222
- # Math
223
- Tool.from_function(func=add_numbers, name="Add Numbers", description="Add two numbers."),
224
- Tool.from_function(func=subtract_numbers, name="Subtract Numbers", description="Subtract two numbers."),
225
- Tool.from_function(func=multiply_numbers, name="Multiply Numbers", description="Multiply two numbers."),
226
- Tool.from_function(func=divide_numbers, name="Divide Numbers", description="Divide two numbers."),
227
- Tool.from_function(func=power, name="Power", description="Raise one number to the power of another."),
228
- Tool.from_function(func=modulus, name="Modulus", description="Compute the modulus (remainder) of a division."),
229
- Tool.from_function(func=square_root, name="Square Root", description="Compute the square root of a number."),
230
- Tool.from_function(func=logarithm, name="Logarithm", description="Compute the logarithm of a number with a given base."),
231
- # Search
232
  web_search_tool,
233
  wikipedia_tool,
234
  arxiv_tool,
235
- # Audio
236
- Tool.from_function(func=transcribe_audio, name="Transcribe Audio", description="Transcribe audio files to text."),
237
- # Youtube
238
- Tool.from_function(func=get_youtube_transcript, name="YouTube Transcript", description="Extract transcript from YouTube video."),
239
- # Image
240
- Tool.from_function(func=extract_text_from_image, name="Image OCR", description="Extract text from an image file."),
241
- # Code Execution
242
- Tool.from_function(func=execute_python_code, name="Python Code Executor", description="Run and return output from a Python script."),
243
- # Excel parsing
244
- Tool.from_function(func=total_sales_from_excel, name="Excel Sales Parser", description="Compute total food sales from Excel file."),
245
  ]
246
 
247
-
248
- ## ----- LLM MODEL ----- ##
249
-
250
- llm = ChatOpenAI(model="gpt-4o", temperature=0)
251
- llm_with_tools = llm.bind_tools(tools)
252
-
253
  ## ----- SYSTEM PROMPT ----- ##
254
 
255
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
256
  system_prompt = f.read()
257
- print(system_prompt)
258
-
259
- # System message
260
  sys_msg = SystemMessage(content=system_prompt)
261
 
262
- ## ----- GRAPH AGENT PIPELINE ----- ##
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
2
  import math
3
  import whisper
 
 
 
4
  import pandas as pd
5
+ import pytesseract
6
+ from PIL import Image
7
  from dotenv import load_dotenv
8
+ from youtube_transcript_api import YouTubeTranscriptApi
9
+ from typing import TypedDict, Dict, Any, Optional, List
10
 
11
+ from langchain.tools import Tool
12
+ from langchain.utilities import WikipediaAPIWrapper, ArxivAPIWrapper, DuckDuckGoSearchRun
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+ from langchain_community.vectorstores import FAISS
15
  from langchain_openai import ChatOpenAI
16
  from langchain_core.messages import HumanMessage, SystemMessage
17
+ from langchain.tools.retriever import create_retriever_tool
18
+ from langgraph.graph import StateGraph, START, END, MessagesState
19
+ from langgraph.prebuilt import ToolNode, tools_condition
20
 
21
+ # Load environment variables
22
  load_dotenv()
 
 
 
23
  openai_api_key = os.getenv("OPENAI_API_KEY")
24
 
25
+ ## ----- TOOL DEFINITIONS ----- ##
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Math Tools
28
+ def add_numbers(a: float, b: float) -> float: return a + b
29
+ def subtract_numbers(a: float, b: float) -> float: return a - b
30
+ def multiply_numbers(a: float, b: float) -> float: return a * b
31
  def divide_numbers(a: float, b: float) -> float:
32
+ if b == 0: raise ValueError("Division by zero")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return a / b
34
+ def power(a: float, b: float) -> float: return a ** b
35
+ def modulus(a: float, b: float) -> float: return a % b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def square_root(a: float) -> float:
37
+ if a < 0: raise ValueError("Cannot compute square root of a negative number")
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return math.sqrt(a)
 
39
  def logarithm(a: float, base: float = math.e) -> float:
40
+ if a <= 0 or base <= 0: raise ValueError("Logarithm arguments must be positive")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return math.log(a, base)
42
 
43
+ # Web Search Tools
 
 
 
44
  web_search_tool = Tool.from_function(
45
+ func=DuckDuckGoSearchRun().run,
46
  name="Web Search",
47
+ description="Search the internet for general-purpose queries."
48
  )
49
 
 
 
50
  wikipedia_tool = Tool.from_function(
51
+ func=WikipediaAPIWrapper().run,
52
  name="Wikipedia Search",
53
+ description="Search Wikipedia for factual or encyclopedic information."
54
  )
55
 
 
 
56
  arxiv_tool = Tool.from_function(
57
+ func=ArxivAPIWrapper().run,
58
  name="ArXiv Search",
59
+ description="Search ArXiv for scientific papers. Input should be a research topic or query."
60
  )
61
 
62
+ # Audio Transcription
63
+ whisper_model = whisper.load_model("base")
64
+ @Tool
 
 
65
  def transcribe_audio(file_path: str) -> str:
66
+ """Transcribe audio files using Whisper."""
67
+ return whisper_model.transcribe(file_path)["text"]
 
68
 
69
+ # YouTube Transcript
70
+ @Tool
 
71
  def get_youtube_transcript(video_id: str) -> str:
72
+ """Extract transcript from YouTube video using video ID."""
73
  transcript = YouTubeTranscriptApi.get_transcript(video_id)
74
+ return " ".join(entry["text"] for entry in transcript)
 
 
75
 
76
+ # OCR Tool
77
+ @Tool
78
  def extract_text_from_image(image_path: str) -> str:
79
+ """Extract text from an image file."""
80
  return pytesseract.image_to_string(Image.open(image_path))
81
 
82
+ # Code Execution
83
+ @Tool
 
84
  def execute_python_code(code: str) -> str:
85
+ """Execute a Python script and return the output."""
86
  try:
87
  local_vars = {}
88
  exec(code, {}, local_vars)
 
90
  except Exception as e:
91
  return f"Error: {e}"
92
 
93
+ # Excel Parsing
94
+ @Tool
 
95
  def total_sales_from_excel(file_path: str) -> str:
96
  """Compute total food sales from an Excel file."""
97
  df = pd.read_excel(file_path)
98
  food_df = df[df["Category"] == "Food"]
99
+ return f"{food_df['Sales'].sum():.2f} USD"
 
 
100
 
101
+ ## ----- TOOL LIST ----- ##
102
 
103
  tools = [
104
+ Tool.from_function(add_numbers, name="Add Numbers", description="Add two numbers."),
105
+ Tool.from_function(subtract_numbers, name="Subtract Numbers", description="Subtract two numbers."),
106
+ Tool.from_function(multiply_numbers, name="Multiply Numbers", description="Multiply two numbers."),
107
+ Tool.from_function(divide_numbers, name="Divide Numbers", description="Divide two numbers."),
108
+ Tool.from_function(power, name="Power", description="Raise one number to the power of another."),
109
+ Tool.from_function(modulus, name="Modulus", description="Compute the modulus (remainder) of a division."),
110
+ Tool.from_function(square_root, name="Square Root", description="Compute the square root of a number."),
111
+ Tool.from_function(logarithm, name="Logarithm", description="Compute the logarithm of a number with a given base."),
 
 
112
  web_search_tool,
113
  wikipedia_tool,
114
  arxiv_tool,
115
+ Tool.from_function(transcribe_audio, name="Transcribe Audio", description="Transcribe audio to text."),
116
+ Tool.from_function(get_youtube_transcript, name="YouTube Transcript", description="Extract transcript from YouTube."),
117
+ Tool.from_function(extract_text_from_image, name="Image OCR", description="Extract text from an image."),
118
+ Tool.from_function(execute_python_code, name="Python Code Executor", description="Run Python code."),
119
+ Tool.from_function(total_sales_from_excel, name="Excel Sales Parser", description="Parse Excel file for total food sales."),
 
 
 
 
 
120
  ]
121
 
 
 
 
 
 
 
122
  ## ----- SYSTEM PROMPT ----- ##
123
 
124
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
125
  system_prompt = f.read()
 
 
 
126
  sys_msg = SystemMessage(content=system_prompt)
127
 
128
+ ## ----- EMBEDDINGS & VECTOR DB (FAISS) ----- ##
129
 
130
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
131
+
132
+ # Ensure `documents` is defined – this should be a list of LangChain Document objects
133
+ # Example: documents = [Document(page_content="Q: What is 2+2? A: 4", metadata={}), ...]
134
+ # If you don't have documents yet, load or define them here.
135
+ documents = [] # <-- You MUST fill this with actual documents
136
+ vector_store = FAISS.from_documents(documents, embeddings)
137
+
138
+ retriever_tool = create_retriever_tool(
139
+ retriever=vector_store.as_retriever(),
140
+ name="Question Search",
141
+ description="Retrieve similar questions from a vector store."
142
+ )
143
+
144
+ ## ----- LLM WITH TOOLS ----- ##
145
+
146
+ llm = ChatOpenAI(model="gpt-4o", temperature=0)
147
+ llm_with_tools = llm.bind_tools(tools)
148
 
149
+ ## ----- GRAPH PIPELINE ----- ##
150
+
151
+ def assistant(state: MessagesState):
152
+ """Assistant node to generate answers."""
153
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
154
+
155
+ # Use a retriever node to inject a similar example
156
+ def retriever(state: MessagesState):
157
+ """Retriever node to provide example context."""
158
+ similar = vector_store.similarity_search(state["messages"][0].content)
159
+ if not similar:
160
+ return {"messages": [sys_msg] + state["messages"]}
161
+ example = HumanMessage(content=f"Similar Q&A for context:\n\n{similar[0].page_content}")
162
+ return {"messages": [sys_msg] + state["messages"] + [example]}
163
+
164
+ # Build graph
165
+ builder = StateGraph(MessagesState)
166
+ builder.add_node("retriever", retriever)
167
+ builder.add_node("assistant", assistant)
168
+ builder.add_node("tools", ToolNode(tools))
169
+
170
+ builder.add_edge(START, "retriever")
171
+ builder.add_edge("retriever", "assistant")
172
+ builder.add_conditional_edges("assistant", tools_condition)
173
+ builder.add_edge("tools", "assistant")
174
+
175
+ graph = builder.compile()
176
+
177
+ ## ----- TESTING (Optional) ----- ##
178
+
179
+ if __name__ == "__main__":
180
+ test_question = "How many albums did Taylor Swift release before 2020?"
181
+ response = graph.invoke({"messages": [HumanMessage(content=test_question)]})
182
+ for msg in response["messages"]:
183
+ msg.pretty_print()