# Spaces:
# Runtime error
import datetime | |
from pydantic import BaseModel, Field | |
from typing import Dict, List, Optional | |
import yfinance as yf | |
import plotly.graph_objs as go | |
import plotly.express as px | |
from prophet import Prophet | |
from workcell.integrations.types import PlotlyPlot | |
class Input(BaseModel):
    # Request model for stock_predictor: the single ticker symbol to forecast.
    # (No class docstring on purpose: pydantic would expose it in the schema.)
    ticker: str = Field(
        "AAPL",
        description="A ticker value, like `AAPL`, etc...",
    )
def load_data(ticker, start=None, end=None):
    """Download daily adjusted close prices for ``ticker`` from Yahoo Finance.

    Args:
        ticker: A single ticker symbol, e.g. ``'AAPL'``, ``'AMZN'`` or ``'GOOG'``.
        start: Start date passed to ``yf.download``; defaults to 2022-01-01.
        end: End date passed to ``yf.download``; defaults to the current time.

    Returns:
        The adjusted close prices indexed by date. For a single ticker this is a
        Series named ``'Adj Close'``.

    Raises:
        ValueError: if Yahoo Finance returns no rows (e.g. an unknown ticker).
    """
    if start is None:
        start = datetime.datetime(2022, 1, 1)
    if end is None:
        end = datetime.datetime.now()  # latest available data
    # Ask explicitly for unadjusted columns: newer yfinance releases default to
    # auto_adjust=True, which drops the 'Adj Close' column entirely.
    data = yf.download(ticker, start=start, end=end, interval='1d', auto_adjust=False)
    if data is None or data.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")
    close = data['Adj Close']
    # Newer yfinance versions return (Price, Ticker) MultiIndex columns even for
    # one ticker, which makes the selection above a one-column DataFrame.
    if getattr(close, 'ndim', 1) == 2 and close.shape[1] == 1:
        close = close.iloc[:, 0]
    if getattr(close, 'ndim', 1) == 1:
        # Stable name so downstream code can rely on 'Adj Close'.
        close = close.rename('Adj Close')
    return close
def preprocess_data(df):
    """Reshape date-indexed prices into the ``ds``/``y`` layout Prophet expects.

    Args:
        df: Prices indexed by date. Accepted forms:
            - a Series (e.g. the output of ``load_data``), whatever its name;
            - a single-column DataFrame, whatever the column is called;
            - a multi-column DataFrame, in which case the ``'Date'`` index and
              the ``'Adj Close'`` column are renamed to ``ds`` and ``y``.

    Returns:
        A new DataFrame with a ``ds`` (timezone-naive date) column and a ``y``
        (price) column.
    """
    # A single-column frame carries the same information as a Series.
    if getattr(df, "ndim", 1) == 2 and df.shape[1] == 1:
        df = df.iloc[:, 0]

    if getattr(df, "ndim", 1) == 1:
        # Bare price series: its own name and its index name are irrelevant.
        df_processed = df.rename("y").rename_axis("ds").reset_index()
    else:
        # Multi-column frame: keep the original name-based mapping.
        df_processed = df.reset_index().rename(columns={"Adj Close": "y", "Date": "ds"})

    # Prophet rejects timezone-aware dates, so drop any timezone.
    if "ds" in df_processed.columns:
        ds_tz = getattr(getattr(df_processed["ds"], "dt", None), "tz", None)
        if ds_tz is not None:
            df_processed["ds"] = df_processed["ds"].dt.tz_localize(None)
    return df_processed
def predict_data(df, periods=30, freq='D'):
    """Fit a Prophet model on historical prices and forecast future values.

    Args:
        df: History with ``ds`` (date) and ``y`` (price) columns, e.g.
            ``preprocess_data(load_data(ticker))``.
        periods: Number of future periods to forecast after the last date in ``df``.
        freq: pandas frequency string for the future dates; the default ``'D'``
            (calendar days) matches the original behaviour.

    Returns:
        A new DataFrame with columns ``ds``, ``yhat``, ``yhat_lower`` and
        ``yhat_upper`` covering the historical dates followed by the future ones.
    """
    model = Prophet()
    model.fit(df)
    # include_history=True so the forecast also spans the fitted period.
    future_prices = model.make_future_dataframe(periods=periods, freq=freq, include_history=True)
    forecast = model.predict(future_prices)
    # .copy() so callers may modify the result without SettingWithCopyWarning.
    return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].copy()
def visualization(df_processed, df_forecast, ticker):
    """Plot historical prices together with the Prophet forecast.

    Args:
        df_processed: History with ``ds`` (date) and ``y`` (price) columns.
        df_forecast: Forecast with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns.
        ticker: Symbol shown in the figure title.

    Returns:
        A ``go.Figure`` with a range slider and 1m/6m/YTD/1y/all range buttons.
    """
    # The uncertainty band is drawn first, so the forecast line sits on top of it.
    # Only the lower bound fills: "tonexty" fills to the previous trace, which is
    # the upper bound. This shades the whole interval once and evenly.
    trace_upper = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_upper"],
        mode="lines",
        line={"color": "#57b8ff"},
        name="Higher uncertainty interval",
    )
    trace_lower = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_lower"],
        mode="lines",
        fill="tonexty",
        line={"color": "#57b8ff"},
        name="Lower uncertainty interval",
    )
    trace_forecast = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat"],
        mode="lines",
        name="Forecast",
    )
    trace_actual = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )
    # Title names the requested ticker (no hard-coded company name).
    layout = go.Layout(title=f"Stock Price Forecast for: {ticker}")
    fig = go.Figure(data=[trace_upper, trace_lower, trace_forecast, trace_actual], layout=layout)
    fig.update_xaxes(
        rangeslider_visible=True,
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        ),
    )
    fig.update_layout(
        hovermode="x",
        # Legend pinned inside the plot's top-left corner.
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    )
    return fig
def stock_predictor(input: Input) -> PlotlyPlot:
    """Forecast a ticker's stock price 30 days ahead with Prophet (Yahoo Finance data).

    The figure produced by ``visualization`` is returned wrapped in ``PlotlyPlot``.
    """
    ticker = input.ticker
    # 1. Fetch prices and reshape them into Prophet's ds/y layout.
    history = preprocess_data(load_data(ticker))
    # 2. Fit the model and forecast with predict_data's default horizon.
    forecast = predict_data(history)
    # 3. Plot history and forecast, then wrap the figure for the caller.
    return PlotlyPlot(data=visualization(history, forecast, ticker))