parjun commited on
Commit
003d36d
·
verified ·
1 Parent(s): 9cb0fcd

Create stock_agent.py

Browse files
Files changed (1) hide show
  1. tools/stock_agent.py +40 -0
tools/stock_agent.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from smolagents.tools import Tool
3
+ import yfinance as yf
4
+ import matplotlib.pyplot as plt
5
+ from io import BytesIO
6
+ import base64
7
+ import matplotlib
8
+
9
+ class StockChartTool(Tool):
10
+ name = "stock_chart"
11
+ description = "Generates a stock price chart for the given ticker symbol. Returns the chart as a base64 encoded PNG image."
12
+ inputs = {'ticker': {'type': 'string', 'description': 'The stock ticker symbol (e.g., AAPL, MSFT).'}}
13
+ output_type = "string" # Base64 encoded image string
14
+
15
+ def forward(self, ticker: str) -> str:
16
+ try:
17
+ data = yf.download(ticker, period="30d") # Download 30 days of data
18
+ if data.empty:
19
+ return "Error: No data found for that ticker."
20
+
21
+ plt.figure(figsize=(10, 6))
22
+ plt.plot(data['Close'])
23
+ plt.title(f"{ticker} Stock Price (Last 30 Days)")
24
+ plt.xlabel("Date")
25
+ plt.ylabel("Closing Price")
26
+ plt.grid(True)
27
+
28
+ img_buf = BytesIO()
29
+ plt.savefig(img_buf, format='png')
30
+ img_buf.seek(0)
31
+ img_base64 = base64.b64encode(img_buf.getvalue()).decode('utf-8')
32
+ plt.close() # Important to close the plot to release memory
33
+
34
+ return f"<img src='data:image/png;base64,{img_base64}' alt='Stock Price Plot'>"
35
+
36
+ except Exception as e:
37
+ return f"Error generating chart: {e}"
38
+
39
+ def __init__(self, *args, **kwargs):
40
+ self.is_initialized = False # Add this line to avoid the warning