Geek7 commited on
Commit
3bb41b1
·
verified ·
1 Parent(s): c28f349

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -34
app.py CHANGED
@@ -86,38 +86,93 @@ def lstm_gru_forecast(data, model_type, steps):
86
  def ensemble_forecast(predictions_list):
87
  return pd.DataFrame(predictions_list).mean(axis=0)
88
 
89
- # Streamlit App
90
- st.title("Stock Price Forecasting App")
91
-
92
- # Load stock data
93
- symbol = 'AAPL' # Replace with the desired stock symbol
94
- start_date = '2021-01-01'
95
- end_date = '2022-01-01'
96
- stock_prices = get_stock_data(symbol, start_date, end_date)
97
-
98
- # ARIMA parameters
99
- arima_order = (3, 0, 0) # Example: AR component (p) is set to 3, differencing (d) is 0, MA component (q) is 0
100
- arima_forecast_steps = 30 # Number of steps to forecast (adjust based on your preference)
101
-
102
- # LSTM and GRU parameters
103
- lstm_gru_forecast_steps = 30 # Number of steps to forecast (adjust based on your preference)
104
-
105
- # ARIMA Forecast
106
- arima_predictions = arima_forecast(stock_prices, arima_order, arima_forecast_steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # LSTM Forecast
109
- lstm_predictions = lstm_gru_forecast(stock_prices, 'lstm', lstm_gru_forecast_steps)
110
-
111
- # GRU Forecast
112
- gru_predictions = lstm_gru_forecast(stock_prices, 'gru', lstm_gru_forecast_steps)
113
-
114
- # Ensemble Forecast (Averaging)
115
- ensemble_predictions = ensemble_forecast([arima_predictions, lstm_predictions, gru_predictions])
116
-
117
- # Plotting
118
- st.write("### Historical Stock Prices and Forecasts")
119
- st.line_chart(stock_prices)
120
- st.line_chart(arima_predictions)
121
- st.line_chart(lstm_predictions)
122
- st.line_chart(gru_predictions)
123
- st.line_chart(ensemble_predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def ensemble_forecast(predictions_list):
87
  return pd.DataFrame(predictions_list).mean(axis=0)
88
 
89
+ # Function to fit ARIMA model and make predictions
90
+ def arima_forecast(data, p, d, q, steps):
91
+ # Differencing
92
+ for i in range(d):
93
+ data_diff = np.diff(data)
94
+ data = data_diff
95
+
96
+ # Autoregressive (AR) and Moving Average (MA) components
97
+ ar_coef = np.zeros(p) if p > 0 else []
98
+ ma_coef = np.zeros(q) if q > 0 else []
99
+
100
+ # Initial prediction
101
+ predictions = list(data[:p])
102
+
103
+ # ARIMA forecasting
104
+ for i in range(len(data) - p):
105
+ ar_term = sum(ar_coef[j] * data[i + p - j - 1] for j in range(p))
106
+ ma_term = sum(ma_coef[j] * (data[i + p - j - 1] - predictions[-1]) for j in range(q))
107
+ next_prediction = data[i + p] + ar_term + ma_term
108
+ predictions.append(next_prediction)
109
+
110
+ # Update coefficients using online learning (optional)
111
+ if i + p + 1 < len(data):
112
+ ar_coef = ar_coef + (2.0 / (i + p + 2)) * (data[i + p + 1] - next_prediction) * np.flip(data[i:i + p])
113
+ ma_coef = ma_coef + (2.0 / (i + p + 2)) * (data[i + p + 1] - next_prediction) * np.flip(predictions[i - q + 1:i + 1])
114
+
115
+ # Inverse differencing
116
+ for i in range(d):
117
+ predictions = np.cumsum([data[p - 1]] + predictions)
118
+
119
+ return predictions[-steps:]
120
 
121
+ # Streamlit App
122
+ def main():
123
+ st.title("Stock Price Forecasting App")
124
+
125
+ # Load stock data using Streamlit sidebar
126
+ symbol = st.sidebar.text_input("Enter Stock Symbol", value='AAPL')
127
+ start_date = st.sidebar.text_input("Enter Start Date", value='2021-01-01')
128
+ end_date = st.sidebar.text_input("Enter End Date", value='2022-01-01')
129
+ stock_prices = get_stock_data(symbol, start_date, end_date)
130
+
131
+ # ARIMA parameters using Streamlit sliders
132
+ p = st.sidebar.slider("AR Component (p)", min_value=0, max_value=10, value=3)
133
+ d = st.sidebar.slider("Differencing (d)", min_value=0, max_value=5, value=0)
134
+ q = st.sidebar.slider("MA Component (q)", min_value=0, max_value=10, value=0)
135
+ arima_forecast_steps = st.sidebar.slider("ARIMA Forecast Steps", min_value=1, max_value=100, value=30)
136
+
137
+ # LSTM and GRU parameters using Streamlit sliders
138
+ lstm_gru_forecast_steps = st.sidebar.slider("LSTM/GRU Forecast Steps", min_value=1, max_value=100, value=30)
139
+
140
+ # Custom ARIMA Forecast using Streamlit button
141
+ if st.sidebar.button("Run Custom ARIMA Forecast"):
142
+ arima_predictions_custom = arima_forecast(stock_prices.values, p, d, q, arima_forecast_steps)
143
+ arima_predictions_custom = pd.Series(arima_predictions_custom, index=pd.date_range(start=stock_prices.index[-1], periods=arima_forecast_steps + 1, freq=stock_prices.index.freq))
144
+
145
+ # Display ARIMA Forecast Plot
146
+ st.subheader("Custom ARIMA Forecast")
147
+ st.line_chart(pd.concat([stock_prices, arima_predictions_custom], axis=1).rename(columns={0: "ARIMA Forecast"}))
148
+
149
+ # LSTM Forecast using Streamlit button
150
+ if st.sidebar.button("Run LSTM Forecast"):
151
+ lstm_predictions = lstm_gru_forecast(stock_prices, 'lstm', lstm_gru_forecast_steps)
152
+
153
+ # Display LSTM Forecast Plot
154
+ st.subheader("LSTM Forecast")
155
+ st.line_chart(pd.concat([stock_prices, pd.Series(lstm_predictions, index=pd.date_range(start=stock_prices.index[-1], periods=lstm_gru_forecast_steps + 1, freq=stock_prices.index.freq))], axis=1).rename(columns={0: "LSTM Forecast"}))
156
+
157
+ # GRU Forecast using Streamlit button
158
+ if st.sidebar.button("Run GRU Forecast"):
159
+ gru_predictions = lstm_gru_forecast(stock_prices, 'gru', lstm_gru_forecast_steps)
160
+
161
+ # Display GRU Forecast Plot
162
+ st.subheader("GRU Forecast")
163
+ st.line_chart(pd.concat([stock_prices, pd.Series(gru_predictions, index=pd.date_range(start=stock_prices.index[-1], periods=lstm_gru_forecast_steps + 1, freq=stock_prices.index.freq))], axis=1).rename(columns={0: "GRU Forecast"}))
164
+
165
+ # Ensemble Forecast using Streamlit button
166
+ if st.sidebar.button("Run Ensemble Forecast"):
167
+ ensemble_predictions = ensemble_forecast([arima_predictions_custom, lstm_predictions, gru_predictions])
168
+
169
+ # Display Ensemble Forecast Plot
170
+ st.subheader("Ensemble Forecast")
171
+ st.line_chart(pd.concat([stock_prices, ensemble_predictions], axis=1).rename(columns={0: "Ensemble Forecast"}))
172
+
173
+ # Plotting Historical Stock Prices
174
+ st.subheader("Historical Stock Prices")
175
+ st.line_chart(stock_prices)
176
+
177
+ if __name__ == "__main__":
178
+ main()