import datetime
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import yfinance as yf
import plotly.graph_objs as go
import plotly.express as px
from prophet import Prophet
from workcell.integrations.types import PlotlyPlot


class Input(BaseModel):
    ticker: str = Field(default="AAPL", description="A ticker value, like `AAPL`, etc...")


def load_data(ticker, start=None, end=None):
    """Download daily adjusted close prices for a single ticker from Yahoo Finance.

    Args:
        ticker: Ticker symbol, e.g. 'AAPL', 'AMZN', 'GOOG'.
        start: First day of the history window; defaults to 2022-01-01.
        end: End of the history window (passed straight to yfinance);
            defaults to the current local time, i.e. the latest data.

    Returns:
        A date-indexed series of adjusted close prices named 'Adj Close'.

    Raises:
        ValueError: if yfinance returns no rows for ``ticker``.
    """
    if start is None:
        start = datetime.datetime(2022, 1, 1)
    if end is None:
        end = datetime.datetime.now()  # latest
    # Ask for unadjusted OHLC explicitly: recent yfinance releases default to
    # auto_adjust=True, which drops the 'Adj Close' column used below.
    data = yf.download(ticker, start=start, end=end, interval='1d', auto_adjust=False)
    if data.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")
    close = data['Adj Close']
    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for a single ticker, so the selection above can be a one-column
    # DataFrame; reduce it to a Series so preprocess_data can rename it.
    if close.ndim == 2:
        close = close.iloc[:, 0]
    return close.rename('Adj Close')


def preprocess_data(df):
    """Convert a date-indexed price series into Prophet's input format.

    Prophet fits on a frame with a 'ds' (date) column and a 'y' (value)
    column: the 'Date' index becomes 'ds' and 'Adj Close' becomes 'y'.
    """
    df_processed = df.reset_index()
    return df_processed.rename(columns={'Adj Close': 'y', 'Date': 'ds'})


def predict_data(df, periods=30):
    """Fit a Prophet model on ``df`` and forecast ``periods`` steps ahead.

    Args:
        df: Frame with 'ds' and 'y' columns, e.g. ``preprocess_data(load_data(t))``.
        periods: Number of future periods (days, Prophet's default
            frequency) to forecast.

    Returns:
        Frame with 'ds', 'yhat', 'yhat_lower' and 'yhat_upper' covering the
        historical dates plus the forecast horizon.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=periods)
    forecast = model.predict(future)
    return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]


def visualization(df_processed, df_forecast, ticker):
    """Plot observed prices together with the Prophet forecast and its uncertainty band.

    Args:
        df_processed: Frame with 'ds'/'y' columns (observed prices).
        df_forecast: Frame with 'ds', 'yhat', 'yhat_lower', 'yhat_upper'.
        ticker: Symbol shown in the chart title.

    Returns:
        A plotly Figure with a range slider and range-selector buttons.
    """
    trace_forecast = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat"],
        mode='lines',
        name="Forecast",
    )
    # Draw the lower bound first without fill, then fill the upper bound down
    # to it ("tonexty" fills to the previously added trace), so the band
    # between the bounds is shaded exactly once.
    trace_low = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_lower"],
        mode='lines',
        line={"color": "#57b8ff"},
        name="Lower uncertainty interval",
    )
    trace_high = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_upper"],
        mode='lines',
        fill="tonexty",
        line={"color": "#57b8ff"},
        name="Higher uncertainty interval",
    )
    trace_close = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )
    # Band first so the forecast and observed lines are drawn on top of it.
    data = [trace_low, trace_high, trace_forecast, trace_close]
    layout = go.Layout(title="Stock Price Forecast for: {}".format(ticker))
    fig = go.Figure(data=data, layout=layout)
    fig.update_xaxes(
        rangeslider_visible=True,
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        ),
    )
    fig.update_layout(
        hovermode="x",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    )
    return fig


def stock_predictor(input: Input) -> PlotlyPlot:
    """Input ticker, predict stocks price in 30 days by prophet. Data from yahoo finance."""
    # Step 1. load data & preprocess
    df = load_data(input.ticker)
    df_processed = preprocess_data(df)
    # Step 2. predict
    df_forecast = predict_data(df_processed)
    # Step 3. visualization
    fig = visualization(df_processed, df_forecast, input.ticker)
    # Step 4. wrap for the workcell output
    return PlotlyPlot(data=fig)