aiqtech commited on
Commit
ae75bb8
ยท
verified ยท
1 Parent(s): b7d0c91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -3
app.py CHANGED
@@ -1,10 +1,17 @@
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
- from sklearn.linear_model import LinearRegression
 
 
 
 
5
  import pandas as pd
 
6
  from datetime import datetime
7
  import plotly.graph_objects as go
 
 
8
 
9
  def download_data(ticker, start_date='2010-01-01'):
10
  """
@@ -45,14 +52,67 @@ def predict_future_prices(ticker, periods=1825):
45
  X_future = pd.to_numeric(pd.Series(range(len(data), len(data) + len(future_lr))))
46
  future_lr['yhat'] = model_lr.predict(X_future.values.reshape(-1, 1))
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # ์˜ˆ์ธก ๊ฒฐ๊ณผ ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
49
  forecast_prophet['ds'] = forecast_prophet['ds'].dt.strftime('%Y-%m-%d')
50
  fig = go.Figure()
51
  fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines', name='Prophet Forecast (Blue)'))
52
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
 
 
 
 
 
53
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
54
 
55
- return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']]
56
 
57
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ๋ฐ ์‹คํ–‰
58
  with gr.Blocks() as app:
@@ -64,11 +124,16 @@ with gr.Blocks() as app:
64
  forecast_chart = gr.Plot(label="Forecast Chart")
65
  forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
66
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
 
 
 
 
 
67
 
68
  forecast_button.click(
69
  fn=predict_future_prices,
70
  inputs=[ticker_input, periods_input],
71
- outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr]
72
  )
73
 
74
  app.launch()
 
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
+ from sklearn.linear_model import LinearRegression, BayesianRidge
5
+ from sklearn.svm import SVR
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from statsmodels.tsa.arima.model import ARIMA
8
+ from xgboost import XGBRegressor
9
  import pandas as pd
10
+ import numpy as np
11
  from datetime import datetime
12
  import plotly.graph_objects as go
13
+ from tensorflow.keras.models import Sequential
14
+ from tensorflow.keras.layers import LSTM, Dense
15
 
16
  def download_data(ticker, start_date='2010-01-01'):
17
  """
 
52
  X_future = pd.to_numeric(pd.Series(range(len(data), len(data) + len(future_lr))))
53
  future_lr['yhat'] = model_lr.predict(X_future.values.reshape(-1, 1))
54
 
55
+ # ARIMA ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
56
+ model_arima = ARIMA(data['y'], order=(1, 1, 1))
57
+ model_arima_fit = model_arima.fit()
58
+ forecast_arima = model_arima_fit.forecast(steps=periods)
59
+ future_arima = pd.DataFrame({'ds': future_dates, 'yhat': forecast_arima})
60
+
61
+ # LSTM ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
62
+ scaler = MinMaxScaler(feature_range=(0, 1))
63
+ scaled_data = scaler.fit_transform(data['y'].values.reshape(-1, 1))
64
+ X_train, y_train = [], []
65
+ for i in range(60, len(scaled_data)):
66
+ X_train.append(scaled_data[i-60:i, 0])
67
+ y_train.append(scaled_data[i, 0])
68
+ X_train, y_train = np.array(X_train), np.array(y_train)
69
+ X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
70
+
71
+ model_lstm = Sequential()
72
+ model_lstm.add(LSTM(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
73
+ model_lstm.add(LSTM(units=50))
74
+ model_lstm.add(Dense(1))
75
+ model_lstm.compile(loss='mean_squared_error', optimizer='adam')
76
+ model_lstm.fit(X_train, y_train, epochs=10, batch_size=32)
77
+
78
+ last_60_days = data['y'][-60:].values
79
+ scaled_last_60_days = scaler.transform(last_60_days.reshape(-1, 1))
80
+ X_test = []
81
+ X_test.append(scaled_last_60_days)
82
+ X_test = np.array(X_test)
83
+ X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
84
+ pred_lstm = model_lstm.predict(X_test)
85
+ pred_lstm = scaler.inverse_transform(pred_lstm)
86
+ future_lstm = pd.DataFrame({'ds': future_dates[:periods], 'yhat': pred_lstm.flatten()})
87
+
88
+ # XGBoost ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
89
+ model_xgb = XGBRegressor(n_estimators=100, learning_rate=0.1)
90
+ model_xgb.fit(X.values.reshape(-1, 1), y)
91
+ future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future.values.reshape(-1, 1))})
92
+
93
+ # SVR ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
94
+ model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
95
+ model_svr.fit(X.values.reshape(-1, 1), y)
96
+ future_svr = pd.DataFrame({'ds': future_dates, 'yhat': model_svr.predict(X_future.values.reshape(-1, 1))})
97
+
98
+ # Bayesian Regression ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
99
+ model_bayes = BayesianRidge()
100
+ model_bayes.fit(X.values.reshape(-1, 1), y)
101
+ future_bayes = pd.DataFrame({'ds': future_dates, 'yhat': model_bayes.predict(X_future.values.reshape(-1, 1))})
102
+
103
  # ์˜ˆ์ธก ๊ฒฐ๊ณผ ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
104
  forecast_prophet['ds'] = forecast_prophet['ds'].dt.strftime('%Y-%m-%d')
105
  fig = go.Figure()
106
  fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines', name='Prophet Forecast (Blue)'))
107
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
108
+ fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
109
+ fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
110
+ fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
111
+ fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
112
+ fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
113
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
114
 
115
+ return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']], future_arima[['ds', 'yhat']], future_lstm[['ds', 'yhat']], future_xgb[['ds', 'yhat']], future_svr[['ds', 'yhat']], future_bayes[['ds', 'yhat']]
116
 
117
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ๋ฐ ์‹คํ–‰
118
  with gr.Blocks() as app:
 
124
  forecast_chart = gr.Plot(label="Forecast Chart")
125
  forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
126
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
127
+ forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
128
+ forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
129
+ forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
130
+ forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
131
+ forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
132
 
133
  forecast_button.click(
134
  fn=predict_future_prices,
135
  inputs=[ticker_input, periods_input],
136
+ outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm, forecast_data_xgb, forecast_data_svr, forecast_data_bayes]
137
  )
138
 
139
  app.launch()