Commit
·
6170174
1
Parent(s):
37a70e9
added doc strings to tool
Browse files
agent.py
CHANGED
@@ -15,6 +15,7 @@ load_dotenv()
|
|
15 |
|
16 |
# Initialize LLM
|
17 |
def initialize_llm():
|
|
|
18 |
llm = ChatGroq(
|
19 |
temperature=0,
|
20 |
model_name="qwen-qwq-32b",
|
@@ -24,10 +25,21 @@ def initialize_llm():
|
|
24 |
|
25 |
# Initialize Tavily Search Tool
|
26 |
def initialize_search_tool():
|
|
|
27 |
return TavilySearchResults()
|
28 |
|
29 |
# Weather tool
|
30 |
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
if search_tool is None:
|
32 |
search_tool = initialize_search_tool()
|
33 |
query = f"current weather in {location}"
|
@@ -35,6 +47,15 @@ def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
|
|
35 |
|
36 |
# Recommendation chain
|
37 |
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
recommendation_prompt = ChatPromptTemplate.from_template("""
|
39 |
You are a helpful assistant that gives weather-based advice.
|
40 |
|
@@ -47,6 +68,16 @@ def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
|
|
47 |
return recommendation_prompt | llm
|
48 |
|
49 |
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
if recommendation_chain is None:
|
51 |
llm = initialize_llm()
|
52 |
recommendation_chain = initialize_recommendation_chain(llm)
|
@@ -55,36 +86,119 @@ def get_recommendation(weather_condition: str, recommendation_chain: Runnable =
|
|
55 |
# Math tools
|
56 |
@tool
|
57 |
def add(x: int, y: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
return x + y
|
59 |
|
60 |
@tool
|
61 |
def subtract(x: int, y: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
return x - y
|
63 |
|
64 |
@tool
|
65 |
def multiply(x: int, y: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
return x * y
|
67 |
|
68 |
@tool
|
69 |
def divide(x: int, y: int) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
if y == 0:
|
71 |
raise ValueError("Cannot divide by zero.")
|
72 |
return x / y
|
73 |
|
74 |
@tool
|
75 |
def square(x: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
return x * x
|
77 |
|
78 |
@tool
|
79 |
def cube(x: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
return x * x * x
|
81 |
|
82 |
@tool
|
83 |
def power(x: int, y: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
return x ** y
|
85 |
|
86 |
@tool
|
87 |
def factorial(n: int) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
if n < 0:
|
89 |
raise ValueError("Factorial is not defined for negative numbers.")
|
90 |
if n == 0 or n == 1:
|
@@ -96,12 +210,36 @@ def factorial(n: int) -> int:
|
|
96 |
|
97 |
@tool
|
98 |
def mean(numbers: list) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
if not numbers:
|
100 |
raise ValueError("The list is empty.")
|
101 |
return sum(numbers) / len(numbers)
|
102 |
|
103 |
@tool
|
104 |
def standard_deviation(numbers: list) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if not numbers:
|
106 |
raise ValueError("The list is empty.")
|
107 |
mean_value = mean(numbers)
|
@@ -110,16 +248,40 @@ def standard_deviation(numbers: list) -> float:
|
|
110 |
|
111 |
# Build the LangGraph
|
112 |
def build_graph():
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
llm = initialize_llm()
|
114 |
search_tool = initialize_search_tool()
|
115 |
recommendation_chain = initialize_recommendation_chain(llm)
|
116 |
|
117 |
@tool
|
118 |
def weather_tool(location: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
return get_weather(location, search_tool)
|
120 |
|
121 |
@tool
|
122 |
def recommendation_tool(weather_condition: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return get_recommendation(weather_condition, recommendation_chain)
|
124 |
|
125 |
tools = [weather_tool, recommendation_tool, add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
@@ -127,6 +289,15 @@ def build_graph():
|
|
127 |
llm_with_tools = llm.bind_tools(tools)
|
128 |
|
129 |
def assistant(state: MessagesState):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
print("Entering assistant node...")
|
131 |
response = llm_with_tools.invoke(state["messages"])
|
132 |
print(f"Assistant says: {response.content}")
|
|
|
15 |
|
16 |
# Initialize LLM
|
17 |
def initialize_llm():
|
18 |
+
"""Initializes the ChatGroq LLM."""
|
19 |
llm = ChatGroq(
|
20 |
temperature=0,
|
21 |
model_name="qwen-qwq-32b",
|
|
|
25 |
|
26 |
# Initialize Tavily Search Tool
|
27 |
def initialize_search_tool():
|
28 |
+
"""Initializes the TavilySearchResults tool."""
|
29 |
return TavilySearchResults()
|
30 |
|
31 |
# Weather tool
|
32 |
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
|
33 |
+
"""
|
34 |
+
Fetches the current weather information for a given location using Tavily search.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
location (str): The name of the location to search for.
|
38 |
+
search_tool (TavilySearchResults, optional): Defaults to None.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
str: The weather information for the specified location.
|
42 |
+
"""
|
43 |
if search_tool is None:
|
44 |
search_tool = initialize_search_tool()
|
45 |
query = f"current weather in {location}"
|
|
|
47 |
|
48 |
# Recommendation chain
|
49 |
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
|
50 |
+
"""
|
51 |
+
Initializes the recommendation chain.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
llm(ChatGroq):The LLM to use
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Runnable: A runnable sequence to generate recommendations.
|
58 |
+
"""
|
59 |
recommendation_prompt = ChatPromptTemplate.from_template("""
|
60 |
You are a helpful assistant that gives weather-based advice.
|
61 |
|
|
|
68 |
return recommendation_prompt | llm
|
69 |
|
70 |
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
|
71 |
+
"""
|
72 |
+
Gives activity/clothing recommendations and health tips based on the weather condition.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
weather_condition (str): The current weather condition.
|
76 |
+
recommendation_chain (Runnable, optional): The recommendation chain to use. Defaults to None.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
str: Recommendations and health tips for the given weather condition.
|
80 |
+
"""
|
81 |
if recommendation_chain is None:
|
82 |
llm = initialize_llm()
|
83 |
recommendation_chain = initialize_recommendation_chain(llm)
|
|
|
86 |
# Math tools
|
87 |
@tool
|
88 |
def add(x: int, y: int) -> int:
|
89 |
+
"""
|
90 |
+
Adds two integers.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
x (int): The first integer.
|
94 |
+
y (int): The second integer.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
int: The sum of x and y.
|
98 |
+
"""
|
99 |
return x + y
|
100 |
|
101 |
@tool
|
102 |
def subtract(x: int, y: int) -> int:
|
103 |
+
"""
|
104 |
+
Subtracts two integers.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
x (int): The first integer.
|
108 |
+
y (int): The second integer.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
int: The difference between x and y.
|
112 |
+
"""
|
113 |
return x - y
|
114 |
|
115 |
@tool
|
116 |
def multiply(x: int, y: int) -> int:
|
117 |
+
"""
|
118 |
+
Multiplies two integers.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
x (int): The first integer.
|
122 |
+
y (int): The second integer.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
int: The product of x and y.
|
126 |
+
"""
|
127 |
return x * y
|
128 |
|
129 |
@tool
|
130 |
def divide(x: int, y: int) -> float:
|
131 |
+
"""
|
132 |
+
Divides two numbers.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
x (int): The numerator.
|
136 |
+
y (int): The denominator.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
float: The result of the division.
|
140 |
+
|
141 |
+
Raises:
|
142 |
+
ValueError: If y is zero.
|
143 |
+
"""
|
144 |
if y == 0:
|
145 |
raise ValueError("Cannot divide by zero.")
|
146 |
return x / y
|
147 |
|
148 |
@tool
|
149 |
def square(x: int) -> int:
|
150 |
+
"""
|
151 |
+
Calculates the square of a number.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x (int): The number to square.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
int: The square of x.
|
158 |
+
"""
|
159 |
return x * x
|
160 |
|
161 |
@tool
|
162 |
def cube(x: int) -> int:
|
163 |
+
"""
|
164 |
+
Calculates the cube of a number.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
x (int): The number to cube.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
int: The cube of x.
|
171 |
+
"""
|
172 |
return x * x * x
|
173 |
|
174 |
@tool
|
175 |
def power(x: int, y: int) -> int:
|
176 |
+
"""
|
177 |
+
Raises a number to the power of another number.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
x (int): The base number.
|
181 |
+
y (int): The exponent.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
int: x raised to the power of y.
|
185 |
+
"""
|
186 |
return x ** y
|
187 |
|
188 |
@tool
|
189 |
def factorial(n: int) -> int:
|
190 |
+
"""
|
191 |
+
Calculates the factorial of a non-negative integer.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
n (int): The non-negative integer.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
int: The factorial of n.
|
198 |
+
|
199 |
+
Raises:
|
200 |
+
ValueError: If n is negative.
|
201 |
+
"""
|
202 |
if n < 0:
|
203 |
raise ValueError("Factorial is not defined for negative numbers.")
|
204 |
if n == 0 or n == 1:
|
|
|
210 |
|
211 |
@tool
|
212 |
def mean(numbers: list) -> float:
|
213 |
+
"""
|
214 |
+
Calculates the mean of a list of numbers.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
numbers (list): A list of numbers.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
float: The mean of the numbers.
|
221 |
+
|
222 |
+
Raises:
|
223 |
+
ValueError: If the list is empty.
|
224 |
+
"""
|
225 |
if not numbers:
|
226 |
raise ValueError("The list is empty.")
|
227 |
return sum(numbers) / len(numbers)
|
228 |
|
229 |
@tool
|
230 |
def standard_deviation(numbers: list) -> float:
|
231 |
+
"""
|
232 |
+
Calculates the standard deviation of a list of numbers.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
numbers (list): A list of numbers.
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
float: The standard deviation of the numbers.
|
239 |
+
|
240 |
+
Raises:
|
241 |
+
ValueError: If the list is empty.
|
242 |
+
"""
|
243 |
if not numbers:
|
244 |
raise ValueError("The list is empty.")
|
245 |
mean_value = mean(numbers)
|
|
|
248 |
|
249 |
# Build the LangGraph
|
250 |
def build_graph():
|
251 |
+
"""
|
252 |
+
Builds the LangGraph with the defined tools and assistant node.
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
RunnableGraph: The compiled LangGraph.
|
256 |
+
"""
|
257 |
llm = initialize_llm()
|
258 |
search_tool = initialize_search_tool()
|
259 |
recommendation_chain = initialize_recommendation_chain(llm)
|
260 |
|
261 |
@tool
|
262 |
def weather_tool(location: str) -> str:
|
263 |
+
"""
|
264 |
+
Fetches the weather for a location.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
location (str): The location to fetch weather for.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
str: The weather information.
|
271 |
+
"""
|
272 |
return get_weather(location, search_tool)
|
273 |
|
274 |
@tool
|
275 |
def recommendation_tool(weather_condition: str) -> str:
|
276 |
+
"""
|
277 |
+
Provides recommendations based on weather conditions.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
weather_condition (str): The weather condition.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
str: The recommendations.
|
284 |
+
"""
|
285 |
return get_recommendation(weather_condition, recommendation_chain)
|
286 |
|
287 |
tools = [weather_tool, recommendation_tool, add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
|
|
289 |
llm_with_tools = llm.bind_tools(tools)
|
290 |
|
291 |
def assistant(state: MessagesState):
|
292 |
+
"""
|
293 |
+
Assistant node in the LangGraph.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
state (MessagesState): The current state of the conversation.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
dict: The next state of the conversation.
|
300 |
+
"""
|
301 |
print("Entering assistant node...")
|
302 |
response = llm_with_tools.invoke(state["messages"])
|
303 |
print(f"Assistant says: {response.content}")
|