HarshitSundriyal commited on
Commit
6170174
·
1 Parent(s): 37a70e9

added doc strings to tool

Browse files
Files changed (1) hide show
  1. agent.py +171 -0
agent.py CHANGED
@@ -15,6 +15,7 @@ load_dotenv()
15
 
16
  # Initialize LLM
17
  def initialize_llm():
 
18
  llm = ChatGroq(
19
  temperature=0,
20
  model_name="qwen-qwq-32b",
@@ -24,10 +25,21 @@ def initialize_llm():
24
 
25
  # Initialize Tavily Search Tool
26
  def initialize_search_tool():
 
27
  return TavilySearchResults()
28
 
29
  # Weather tool
30
  def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
 
 
 
 
 
 
 
 
 
 
31
  if search_tool is None:
32
  search_tool = initialize_search_tool()
33
  query = f"current weather in {location}"
@@ -35,6 +47,15 @@ def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
35
 
36
  # Recommendation chain
37
  def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
 
 
 
 
 
 
 
 
 
38
  recommendation_prompt = ChatPromptTemplate.from_template("""
39
  You are a helpful assistant that gives weather-based advice.
40
 
@@ -47,6 +68,16 @@ def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
47
  return recommendation_prompt | llm
48
 
49
  def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
 
 
 
 
 
 
 
 
 
 
50
  if recommendation_chain is None:
51
  llm = initialize_llm()
52
  recommendation_chain = initialize_recommendation_chain(llm)
@@ -55,36 +86,119 @@ def get_recommendation(weather_condition: str, recommendation_chain: Runnable =
55
  # Math tools
56
  @tool
57
  def add(x: int, y: int) -> int:
 
 
 
 
 
 
 
 
 
 
58
  return x + y
59
 
60
  @tool
61
  def subtract(x: int, y: int) -> int:
 
 
 
 
 
 
 
 
 
 
62
  return x - y
63
 
64
  @tool
65
  def multiply(x: int, y: int) -> int:
 
 
 
 
 
 
 
 
 
 
66
  return x * y
67
 
68
  @tool
69
  def divide(x: int, y: int) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  if y == 0:
71
  raise ValueError("Cannot divide by zero.")
72
  return x / y
73
 
74
  @tool
75
  def square(x: int) -> int:
 
 
 
 
 
 
 
 
 
76
  return x * x
77
 
78
  @tool
79
  def cube(x: int) -> int:
 
 
 
 
 
 
 
 
 
80
  return x * x * x
81
 
82
  @tool
83
  def power(x: int, y: int) -> int:
 
 
 
 
 
 
 
 
 
 
84
  return x ** y
85
 
86
  @tool
87
  def factorial(n: int) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
88
  if n < 0:
89
  raise ValueError("Factorial is not defined for negative numbers.")
90
  if n == 0 or n == 1:
@@ -96,12 +210,36 @@ def factorial(n: int) -> int:
96
 
97
  @tool
98
  def mean(numbers: list) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
99
  if not numbers:
100
  raise ValueError("The list is empty.")
101
  return sum(numbers) / len(numbers)
102
 
103
  @tool
104
  def standard_deviation(numbers: list) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
105
  if not numbers:
106
  raise ValueError("The list is empty.")
107
  mean_value = mean(numbers)
@@ -110,16 +248,40 @@ def standard_deviation(numbers: list) -> float:
110
 
111
  # Build the LangGraph
112
  def build_graph():
 
 
 
 
 
 
113
  llm = initialize_llm()
114
  search_tool = initialize_search_tool()
115
  recommendation_chain = initialize_recommendation_chain(llm)
116
 
117
  @tool
118
  def weather_tool(location: str) -> str:
 
 
 
 
 
 
 
 
 
119
  return get_weather(location, search_tool)
120
 
121
  @tool
122
  def recommendation_tool(weather_condition: str) -> str:
 
 
 
 
 
 
 
 
 
123
  return get_recommendation(weather_condition, recommendation_chain)
124
 
125
  tools = [weather_tool, recommendation_tool, add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
@@ -127,6 +289,15 @@ def build_graph():
127
  llm_with_tools = llm.bind_tools(tools)
128
 
129
  def assistant(state: MessagesState):
 
 
 
 
 
 
 
 
 
130
  print("Entering assistant node...")
131
  response = llm_with_tools.invoke(state["messages"])
132
  print(f"Assistant says: {response.content}")
 
15
 
16
  # Initialize LLM
17
  def initialize_llm():
18
+ """Initializes the ChatGroq LLM."""
19
  llm = ChatGroq(
20
  temperature=0,
21
  model_name="qwen-qwq-32b",
 
25
 
26
  # Initialize Tavily Search Tool
27
  def initialize_search_tool():
28
+ """Initializes the TavilySearchResults tool."""
29
  return TavilySearchResults()
30
 
31
  # Weather tool
32
  def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
33
+ """
34
+ Fetches the current weather information for a given location using Tavily search.
35
+
36
+ Args:
37
+ location (str): The name of the location to search for.
38
+ search_tool (TavilySearchResults, optional): Defaults to None.
39
+
40
+ Returns:
41
+ str: The weather information for the specified location.
42
+ """
43
  if search_tool is None:
44
  search_tool = initialize_search_tool()
45
  query = f"current weather in {location}"
 
47
 
48
  # Recommendation chain
49
  def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
50
+ """
51
+ Initializes the recommendation chain.
52
+
53
+ Args:
54
+ llm(ChatGroq):The LLM to use
55
+
56
+ Returns:
57
+ Runnable: A runnable sequence to generate recommendations.
58
+ """
59
  recommendation_prompt = ChatPromptTemplate.from_template("""
60
  You are a helpful assistant that gives weather-based advice.
61
 
 
68
  return recommendation_prompt | llm
69
 
70
  def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
71
+ """
72
+ Gives activity/clothing recommendations and health tips based on the weather condition.
73
+
74
+ Args:
75
+ weather_condition (str): The current weather condition.
76
+ recommendation_chain (Runnable, optional): The recommendation chain to use. Defaults to None.
77
+
78
+ Returns:
79
+ str: Recommendations and health tips for the given weather condition.
80
+ """
81
  if recommendation_chain is None:
82
  llm = initialize_llm()
83
  recommendation_chain = initialize_recommendation_chain(llm)
 
86
  # Math tools
87
  @tool
88
  def add(x: int, y: int) -> int:
89
+ """
90
+ Adds two integers.
91
+
92
+ Args:
93
+ x (int): The first integer.
94
+ y (int): The second integer.
95
+
96
+ Returns:
97
+ int: The sum of x and y.
98
+ """
99
  return x + y
100
 
101
  @tool
102
  def subtract(x: int, y: int) -> int:
103
+ """
104
+ Subtracts two integers.
105
+
106
+ Args:
107
+ x (int): The first integer.
108
+ y (int): The second integer.
109
+
110
+ Returns:
111
+ int: The difference between x and y.
112
+ """
113
  return x - y
114
 
115
  @tool
116
  def multiply(x: int, y: int) -> int:
117
+ """
118
+ Multiplies two integers.
119
+
120
+ Args:
121
+ x (int): The first integer.
122
+ y (int): The second integer.
123
+
124
+ Returns:
125
+ int: The product of x and y.
126
+ """
127
  return x * y
128
 
129
  @tool
130
  def divide(x: int, y: int) -> float:
131
+ """
132
+ Divides two numbers.
133
+
134
+ Args:
135
+ x (int): The numerator.
136
+ y (int): The denominator.
137
+
138
+ Returns:
139
+ float: The result of the division.
140
+
141
+ Raises:
142
+ ValueError: If y is zero.
143
+ """
144
  if y == 0:
145
  raise ValueError("Cannot divide by zero.")
146
  return x / y
147
 
148
  @tool
149
  def square(x: int) -> int:
150
+ """
151
+ Calculates the square of a number.
152
+
153
+ Args:
154
+ x (int): The number to square.
155
+
156
+ Returns:
157
+ int: The square of x.
158
+ """
159
  return x * x
160
 
161
  @tool
162
  def cube(x: int) -> int:
163
+ """
164
+ Calculates the cube of a number.
165
+
166
+ Args:
167
+ x (int): The number to cube.
168
+
169
+ Returns:
170
+ int: The cube of x.
171
+ """
172
  return x * x * x
173
 
174
  @tool
175
  def power(x: int, y: int) -> int:
176
+ """
177
+ Raises a number to the power of another number.
178
+
179
+ Args:
180
+ x (int): The base number.
181
+ y (int): The exponent.
182
+
183
+ Returns:
184
+ int: x raised to the power of y.
185
+ """
186
  return x ** y
187
 
188
  @tool
189
  def factorial(n: int) -> int:
190
+ """
191
+ Calculates the factorial of a non-negative integer.
192
+
193
+ Args:
194
+ n (int): The non-negative integer.
195
+
196
+ Returns:
197
+ int: The factorial of n.
198
+
199
+ Raises:
200
+ ValueError: If n is negative.
201
+ """
202
  if n < 0:
203
  raise ValueError("Factorial is not defined for negative numbers.")
204
  if n == 0 or n == 1:
 
210
 
211
  @tool
212
  def mean(numbers: list) -> float:
213
+ """
214
+ Calculates the mean of a list of numbers.
215
+
216
+ Args:
217
+ numbers (list): A list of numbers.
218
+
219
+ Returns:
220
+ float: The mean of the numbers.
221
+
222
+ Raises:
223
+ ValueError: If the list is empty.
224
+ """
225
  if not numbers:
226
  raise ValueError("The list is empty.")
227
  return sum(numbers) / len(numbers)
228
 
229
  @tool
230
  def standard_deviation(numbers: list) -> float:
231
+ """
232
+ Calculates the standard deviation of a list of numbers.
233
+
234
+ Args:
235
+ numbers (list): A list of numbers.
236
+
237
+ Returns:
238
+ float: The standard deviation of the numbers.
239
+
240
+ Raises:
241
+ ValueError: If the list is empty.
242
+ """
243
  if not numbers:
244
  raise ValueError("The list is empty.")
245
  mean_value = mean(numbers)
 
248
 
249
  # Build the LangGraph
250
  def build_graph():
251
+ """
252
+ Builds the LangGraph with the defined tools and assistant node.
253
+
254
+ Returns:
255
+ RunnableGraph: The compiled LangGraph.
256
+ """
257
  llm = initialize_llm()
258
  search_tool = initialize_search_tool()
259
  recommendation_chain = initialize_recommendation_chain(llm)
260
 
261
  @tool
262
  def weather_tool(location: str) -> str:
263
+ """
264
+ Fetches the weather for a location.
265
+
266
+ Args:
267
+ location (str): The location to fetch weather for.
268
+
269
+ Returns:
270
+ str: The weather information.
271
+ """
272
  return get_weather(location, search_tool)
273
 
274
  @tool
275
  def recommendation_tool(weather_condition: str) -> str:
276
+ """
277
+ Provides recommendations based on weather conditions.
278
+
279
+ Args:
280
+ weather_condition (str): The weather condition.
281
+
282
+ Returns:
283
+ str: The recommendations.
284
+ """
285
  return get_recommendation(weather_condition, recommendation_chain)
286
 
287
  tools = [weather_tool, recommendation_tool, add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
 
289
  llm_with_tools = llm.bind_tools(tools)
290
 
291
  def assistant(state: MessagesState):
292
+ """
293
+ Assistant node in the LangGraph.
294
+
295
+ Args:
296
+ state (MessagesState): The current state of the conversation.
297
+
298
+ Returns:
299
+ dict: The next state of the conversation.
300
+ """
301
  print("Entering assistant node...")
302
  response = llm_with_tools.invoke(state["messages"])
303
  print(f"Assistant says: {response.content}")