Geek7 commited on
Commit
ebf57c0
·
verified ·
1 Parent(s): b8ca998

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import yfinance as yf
5
+ from TSEnsemble.ensemble import Ensemble
6
+ from TSEnsemble import arima, nn, utils
7
+
8
+ # Function to load stock data using yfinance
9
+ def get_stock_data(symbol, start_date, end_date):
10
+ stock_data = yf.download(symbol, start=start_date, end=end_date)
11
+ return stock_data['Close']
12
+
13
+ # Load stock data
14
+ symbol = 'AAPL' # Replace with the desired stock symbol
15
+ start_date = '2020-01-01'
16
+ end_date = '2023-01-01'
17
+ stock_prices = get_stock_data(symbol, start_date, end_date)
18
+
19
+ # Set up ARIMA, CNN, LSTM, and Transformer models
20
+ ar = arima.auto_arima(stock_prices, method='stepwise', season=12, max_p=3, max_q=3, max_Q=3, max_P=3, train_split=0.8, plot=False)
21
+
22
+ transformer = nn.generate_transformer(
23
+ look_back=12,
24
+ horizon=1,
25
+ n_features=1,
26
+ num_transformer_blocks=4,
27
+ dropout=0.25,
28
+ head_size=256,
29
+ num_heads=4,
30
+ ff_dim=4,
31
+ mlp_units=[128],
32
+ mlp_dropout=0.4
33
+ )
34
+
35
+ lstm = nn.generate_rnn(look_back=12, hidden_layers=1, units=64, type="LSTM", dropout=0.0)
36
+
37
+ cnn = nn.generate_cnn(look_back=12, hidden_layers=3, kernel_size=2, filters=64, dilation_rate=1, dilation_mode="multiplicative")
38
+
39
+ # Create an ensemble model
40
+ ensemble_model = Ensemble(models=[ar, cnn, lstm, transformer], regressor='wmean')
41
+
42
+ # Fit the ensemble model
43
+ ensemble_model.fit(stock_prices, train_size=0.8, look_back=12, val_size=0.2, train_models_size=0.7, epochs=20, batch_size=16, metric="rmse")
44
+
45
+ # Forecast with the ensemble model
46
+ ensemble_forecast = ensemble_model.forecast(stock_prices, steps=12, fig_size=(10, 6))
47
+
48
+ # Streamlit app
49
+ st.title("Stock Price Prediction App")
50
+
51
+ # Display historical stock prices
52
+ st.subheader("Historical Stock Prices")
53
+ st.line_chart(stock_prices)
54
+
55
+ # Display ensemble forecast
56
+ st.subheader("Ensemble Forecast")
57
+ st.line_chart(ensemble_forecast)
58
+
59
+ # Display ARIMA forecast
60
+ arima_forecast = utils.model_forecast(ar, stock_prices, steps=12)
61
+ st.subheader("ARIMA Forecast")
62
+ st.line_chart(arima_forecast)