jiandong commited on
Commit
9af85d0
·
1 Parent(s): 5eece23

Upload with huggingface_hub

Browse files
Files changed (4) hide show
  1. Dockerfile +20 -0
  2. app.py +126 -0
  3. requirements.txt +4 -0
  4. workcell.yaml +10 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.8
5
+
6
+ # Set up a new user named "user" with user ID 1000
7
+ RUN useradd -m -u 1000 user
8
+ # Switch to the "user" user
9
+ USER user
10
+ # Set home to the user's home directory
11
+ ENV HOME=/home/user \
12
+ PATH=/home/user/.local/bin:$PATH
13
+ # Set the working directory to the user's home directory
14
+ WORKDIR $HOME/app
15
+
16
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
17
+ COPY --chown=user . $HOME/app
18
+ RUN pip install --no-cache-dir --upgrade -r $HOME/app/requirements.txt
19
+
20
+ CMD ["workcell", "serve", "--config", "workcell.yaml", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from pydantic import BaseModel, Field
3
+ from typing import Dict, List, Optional
4
+ import yfinance as yf
5
+ import plotly.graph_objs as go
6
+ import plotly.express as px
7
+ from prophet import Prophet
8
+ from workcell.integrations.types import PlotlyPlot
9
+
10
+
11
+ class Input(BaseModel):
12
+ ticker: str = Field(default="AAPL", description="A ticker value, like `AAPL`, etc...")
13
+
14
+
15
+ def load_data(ticker):
16
+ """Download ticker price data from ticker.
17
+ e.g. ticker = 'AAPL'|'AMZN'|'GOOG'
18
+ """
19
+ start = datetime.datetime(2022, 1, 1)
20
+ end = datetime.datetime.now() # latest
21
+ data = yf.download(ticker, start=start, end=end, interval='1d')
22
+ # adjust close
23
+ close = data['Adj Close']
24
+ return close
25
+
26
+
27
+ def preprocess_data(df):
28
+ """
29
+ Preprocess dataframe for prediction.
30
+ - Filter out predict value.
31
+ """
32
+ # post process
33
+ df_processed = df.reset_index()
34
+ df_processed.rename(columns={'Adj Close': 'y', 'Date': 'ds'}, inplace=True)
35
+ return df_processed
36
+
37
+
38
+ def predict_data(df, periods=30):
39
+ """Predict future prices by prophet.
40
+ e.g. df = preprocess_df(df)
41
+ """
42
+ # init prophet model
43
+ model = Prophet()
44
+ # fit
45
+ model.fit(df)
46
+ # predict data
47
+ future_prices = model.make_future_dataframe(periods=periods)
48
+ forecast = model.predict(future_prices)
49
+ # forecast data
50
+ df_forecast = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]
51
+ return df_forecast
52
+
53
+
54
+ def visualization(df_processed, df_forecast, ticker):
55
+ """Visualization price plot by df_forecast dataframe.
56
+ """
57
+ trace_open = go.Scatter(
58
+ x = df_forecast["ds"],
59
+ y = df_forecast["yhat"],
60
+ mode = 'lines',
61
+ name="Forecast"
62
+ )
63
+
64
+ trace_high = go.Scatter(
65
+ x = df_forecast["ds"],
66
+ y = df_forecast["yhat_upper"],
67
+ mode = 'lines',
68
+ fill = "tonexty",
69
+ line = {"color": "#57b8ff"},
70
+ name="Higher uncertainty interval"
71
+ )
72
+
73
+ trace_low = go.Scatter(
74
+ x = df_forecast["ds"],
75
+ y = df_forecast["yhat_lower"],
76
+ mode = 'lines',
77
+ fill = "tonexty",
78
+ line = {"color": "#57b8ff"},
79
+ name="Lower uncertainty interval"
80
+ )
81
+
82
+ trace_close = go.Scatter(
83
+ x = df_processed["ds"],
84
+ y = df_processed["y"],
85
+ name="Data values"
86
+ )
87
+
88
+ data = [trace_open,trace_high,trace_low,trace_close]
89
+ layout = go.Layout(title="Repsol Stock Price Forecast for: {}".format(ticker), xaxis_rangeslider_visible=True)
90
+ fig = go.Figure(data=data,layout=layout)
91
+ fig.update_xaxes(
92
+ rangeslider_visible=True,
93
+ rangeselector=dict(
94
+ buttons=list([
95
+ dict(count=1, label="1m", step="month", stepmode="backward"),
96
+ dict(count=6, label="6m", step="month", stepmode="backward"),
97
+ dict(count=1, label="YTD", step="year", stepmode="todate"),
98
+ dict(count=1, label="1y", step="year", stepmode="backward"),
99
+ dict(step="all")
100
+ ])
101
+ )
102
+ )
103
+ fig.update_layout(
104
+ hovermode="x",
105
+ legend=dict(
106
+ yanchor="top",
107
+ y=0.99,
108
+ xanchor="left",
109
+ x=0.01
110
+ )
111
+ )
112
+ return fig
113
+
114
+
115
+ def stock_predictor(input: Input) -> PlotlyPlot:
116
+ """Input ticker, predict stocks price in 30 days by prophet. Data from yahoo finance."""
117
+ # Step1. load data & preprocess
118
+ df = load_data(input.ticker)
119
+ df_processed = preprocess_data(df)
120
+ # Step2. predict
121
+ df_forecast = predict_data(df_processed)
122
+ # Step3. visualization
123
+ fig = visualization(df_processed, df_forecast, input.ticker)
124
+ # Step3. wrapped by output
125
+ output = PlotlyPlot(data=fig)
126
+ return output
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ workcell
2
+ yfinance
3
+ plotly
4
+ prophet
workcell.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ workcell_name: stock_predictor
2
+ workcell_provider: huggingface
3
+ workcell_id: weanalyze/stock_predictor
4
+ workcell_version: latest
5
+ workcell_runtime: python3.8
6
+ workcell_entrypoint: app:stock_predictor
7
+ workcell_code:
8
+ ImageUri: ''
9
+ workcell_tags: {}
10
+ workcell_envs: {}