stock_predictor / app.py
jiandong's picture
Upload with huggingface_hub
9af85d0
import datetime
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import yfinance as yf
import plotly.graph_objs as go
import plotly.express as px
from prophet import Prophet
from workcell.integrations.types import PlotlyPlot
class Input(BaseModel):
    # Request schema for the predictor; the only user-supplied value is the
    # ticker symbol to forecast.
    # Documented with `#` comments instead of a class docstring because
    # pydantic copies a model's docstring into the generated schema
    # description, which would change what consumers of the schema see.
    ticker: str = Field(
        default="AAPL",
        description="A ticker value, like `AAPL`, etc...",
    )
def load_data(ticker):
    """Download the daily adjusted-close price series for one ticker.

    Prices are requested from Yahoo Finance through ``yfinance`` for the
    window starting 2022-01-01 and ending at the current local time.

    Args:
        ticker: Ticker symbol to download, e.g. ``'AAPL'``, ``'AMZN'`` or
            ``'GOOG'``.

    Returns:
        The ``'Adj Close'`` prices indexed by date. For a single ticker this
        is a pandas Series named ``'Adj Close'``.

    Raises:
        ValueError: If the download produced no rows for ``ticker``.
    """
    start = datetime.datetime(2022, 1, 1)
    end = datetime.datetime.now()  # up to the latest available bar
    # Ask for unadjusted OHLC explicitly: newer yfinance releases default to
    # auto_adjust=True, which removes the 'Adj Close' column read below.
    data = yf.download(
        ticker,
        start=start,
        end=end,
        interval='1d',
        auto_adjust=False,
    )
    # Fail here with a clear message instead of letting an empty frame reach
    # the model fit, where the resulting error does not mention the ticker.
    if data is None or data.empty:
        raise ValueError(
            "No price data returned for ticker {!r}".format(ticker)
        )
    close = data['Adj Close']
    # Newer yfinance releases return MultiIndex columns even for one ticker,
    # so the selection above is a one-column DataFrame named after the ticker.
    # Collapse it to a Series named 'Adj Close' so callers that rename that
    # column keep working. Multi-column results are passed through unchanged.
    if getattr(close, 'ndim', 1) == 2 and close.shape[1] == 1:
        close = close.iloc[:, 0].rename('Adj Close')
    return close
def preprocess_data(df):
    """Reshape downloaded prices into the ``ds``/``y`` layout used for fitting.

    The index is moved into an ordinary column, after which a column called
    ``Date`` is renamed to ``ds`` and a column called ``Adj Close`` is renamed
    to ``y``. Columns with any other name are left as they are, and the input
    object itself is not modified.
    """
    column_map = {'Adj Close': 'y', 'Date': 'ds'}
    return df.reset_index().rename(columns=column_map)
def predict_data(df, periods=30):
    """Fit a Prophet model on ``df`` and return its forecast.

    Args:
        df: Training frame in the layout produced by ``preprocess_data``
            (a ``ds`` date column and a ``y`` value column).
        periods: Number of future periods passed to
            ``make_future_dataframe``.

    Returns:
        The ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper`` columns of the
        Prophet prediction for the frame built by ``make_future_dataframe``.
    """
    forecast_columns = ['ds', 'yhat', 'yhat_lower', 'yhat_upper']

    # A fresh model with default settings is trained on every call.
    forecaster = Prophet()
    forecaster.fit(df)

    # Build the dates to predict for, then keep only the point estimate and
    # its uncertainty bounds.
    horizon = forecaster.make_future_dataframe(periods=periods)
    prediction = forecaster.predict(horizon)
    return prediction[forecast_columns]
def visualization(df_processed, df_forecast, ticker):
    """Build an interactive Plotly figure of the price history and forecast.

    Args:
        df_processed: Historical prices with ``ds`` and ``y`` columns.
        df_forecast: Forecast with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns.
        ticker: Ticker symbol shown in the chart title.

    Returns:
        A ``go.Figure`` holding four traces (forecast, upper bound, lower
        bound, observed values) with a range slider, range-selector buttons
        and a legend anchored to the top-left corner.
    """
    band_color = "#57b8ff"
    forecast_dates = df_forecast["ds"]

    trace_open = go.Scatter(
        x=forecast_dates,
        y=df_forecast["yhat"],
        mode='lines',
        name="Forecast",
    )
    # fill="tonexty" shades toward the trace added just before this one, so
    # the order of the traces in `data` below determines the shaded regions.
    trace_high = go.Scatter(
        x=forecast_dates,
        y=df_forecast["yhat_upper"],
        mode='lines',
        fill="tonexty",
        line={"color": band_color},
        name="Higher uncertainty interval",
    )
    trace_low = go.Scatter(
        x=forecast_dates,
        y=df_forecast["yhat_lower"],
        mode='lines',
        fill="tonexty",
        line={"color": band_color},
        name="Lower uncertainty interval",
    )
    trace_close = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )
    data = [trace_open, trace_high, trace_low, trace_close]

    # The title previously hard-coded a single company name ("Repsol") even
    # though the chart is drawn for whichever ticker the caller passes in.
    layout = go.Layout(
        title=f"Stock Price Forecast for: {ticker}",
        xaxis_rangeslider_visible=True,
    )
    fig = go.Figure(data=data, layout=layout)

    # Quick zoom presets shown above the plot; the range slider itself is
    # already enabled through the layout.
    fig.update_xaxes(
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        )
    )
    fig.update_layout(
        hovermode="x",
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig
def stock_predictor(input: Input) -> PlotlyPlot:
    """Input ticker, predict stocks price in 30 days by prophet. Data from yahoo finance."""
    # NOTE(review): the docstring above is kept verbatim in case the hosting
    # runtime surfaces it as the app description -- confirm before rewording.
    symbol = input.ticker

    # Step 1: download the price history and reshape it for the model.
    history = preprocess_data(load_data(symbol))

    # Step 2: forecast with the default horizon of predict_data.
    forecast = predict_data(history)

    # Step 3: draw history and forecast on one figure.
    figure = visualization(history, forecast, symbol)

    # Step 4: wrap the figure in the declared output type.
    return PlotlyPlot(data=figure)