aiqtech commited on
Commit
f74a35a
ยท
verified ยท
1 Parent(s): c9dcef9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -77
app.py CHANGED
@@ -1,16 +1,10 @@
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
- from sklearn.linear_model import LinearRegression, BayesianRidge
5
- from sklearn.svm import SVR
6
- from sklearn.preprocessing import MinMaxScaler
7
- from statsmodels.tsa.arima.model import ARIMA
8
  import pandas as pd
9
- import numpy as np
10
  from datetime import datetime
11
  import plotly.graph_objects as go
12
- from tensorflow.keras.models import Sequential
13
- from tensorflow.keras.layers import LSTM, Dense
14
 
15
  def download_data(ticker, start_date='2010-01-01'):
16
  """
@@ -40,82 +34,26 @@ def predict_future_prices(ticker, periods=1825):
40
 
41
  # Linear Regression ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
42
  model_lr = LinearRegression()
43
- X = pd.Series(range(len(data))).values.reshape(-1, 1)
44
  y = data['y'].values
45
- model_lr.fit(X, y)
46
 
47
  # ๋ฏธ๋ž˜ ๋ฐ์ดํ„ฐ ํ”„๋ ˆ์ž„ ์ƒ์„ฑ ๋ฐ ์˜ˆ์ธก
48
- future_dates = pd.date_range(start=data['ds'].iloc[-1], periods=periods+1, freq='D')[1:].strftime('%Y-%m-%d')
49
- X_future = pd.Series(range(len(data), len(data) + len(future_dates))).values.reshape(-1, 1)
50
- future_lr = pd.DataFrame({'ds': future_dates, 'yhat': model_lr.predict(X_future)})
 
 
51
 
52
-
53
- # ARIMA ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
54
- model_arima = ARIMA(data['y'], order=(1, 1, 1))
55
- model_arima_fit = model_arima.fit()
56
- forecast_arima = model_arima_fit.forecast(steps=periods)
57
- future_arima = pd.DataFrame({'ds': future_dates, 'yhat': forecast_arima})
58
-
59
-
60
- # LSTM ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
61
- scaler = MinMaxScaler(feature_range=(0, 1))
62
- scaled_data = scaler.fit_transform(data['y'].values.reshape(-1, 1))
63
- X_train, y_train = [], []
64
- for i in range(60, len(scaled_data)):
65
- X_train.append(scaled_data[i-60:i, 0])
66
- y_train.append(scaled_data[i, 0])
67
- X_train, y_train = np.array(X_train), np.array(y_train)
68
- X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
69
-
70
- model_lstm = Sequential()
71
- model_lstm.add(LSTM(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
72
- model_lstm.add(LSTM(units=50))
73
- model_lstm.add(Dense(1))
74
- model_lstm.compile(loss='mean_squared_error', optimizer='adam')
75
- model_lstm.fit(X_train, y_train, epochs=10, batch_size=32)
76
-
77
- pred_lstm = []
78
- last_60_days = scaled_data[-60:]
79
- for i in range(periods):
80
- X_test = last_60_days.reshape(1, 60, 1)
81
- pred = model_lstm.predict(X_test)
82
- last_60_days = np.append(last_60_days[1:], pred)
83
- pred_lstm.append(pred[0, 0])
84
-
85
- pred_lstm = scaler.inverse_transform(np.array(pred_lstm).reshape(-1, 1))
86
- future_lstm = pd.DataFrame({'ds': future_dates[:len(pred_lstm)], 'yhat': pred_lstm.flatten()})
87
-
88
- # # XGBoost ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
89
- # model_xgb = XGBRegressor(n_estimators=100, learning_rate=0.1)
90
- # model_xgb.fit(X.reshape(-1, 1), y)
91
- # future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future)})
92
-
93
- # # SVR ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
94
- # model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
95
- # model_svr.fit(X.reshape(-1, 1), y)
96
- # future_svr = pd.DataFrame({'ds': future_dates, 'yhat': model_svr.predict(X_future)})
97
-
98
- # # Bayesian Regression ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
99
- # model_bayes = BayesianRidge()
100
- # model_bayes.fit(X.reshape(-1, 1), y)
101
- # future_bayes = pd.DataFrame({'ds': future_dates, 'yhat': model_bayes.predict(X_future)})
102
-
103
-
104
  # ์˜ˆ์ธก ๊ฒฐ๊ณผ ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
105
  forecast_prophet['ds'] = forecast_prophet['ds'].dt.strftime('%Y-%m-%d')
106
  fig = go.Figure()
107
  fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines', name='Prophet Forecast (Blue)'))
108
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
109
- fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
110
- fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
111
- # fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
112
- # fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
113
- # fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
114
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
115
 
116
- return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']], future_arima[['ds', 'yhat']], future_lstm[['ds', 'yhat']], future_xgb[['ds', 'yhat']], future_svr[['ds', 'yhat']], future_bayes[['ds', 'yhat']]
117
 
118
-
119
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ๋ฐ ์‹คํ–‰
120
  with gr.Blocks() as app:
121
  with gr.Row():
@@ -126,16 +64,11 @@ with gr.Blocks() as app:
126
  forecast_chart = gr.Plot(label="Forecast Chart")
127
  forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
128
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
129
- forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
130
- forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
131
- # forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
132
- # forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
133
- # forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
134
 
135
  forecast_button.click(
136
  fn=predict_future_prices,
137
  inputs=[ticker_input, periods_input],
138
- outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm] #,forecast_data_xgb, forecast_data_svr, forecast_data_bayes]
139
  )
140
 
141
  app.launch()
 
1
  import gradio as gr
2
  import yfinance as yf
3
  from prophet import Prophet
4
+ from sklearn.linear_model import LinearRegression
 
 
 
5
  import pandas as pd
 
6
  from datetime import datetime
7
  import plotly.graph_objects as go
 
 
8
 
9
  def download_data(ticker, start_date='2010-01-01'):
10
  """
 
34
 
35
  # Linear Regression ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํ•™์Šต
36
  model_lr = LinearRegression()
37
+ X = pd.to_numeric(pd.Series(range(len(data))))
38
  y = data['y'].values
39
+ model_lr.fit(X.values.reshape(-1, 1), y)
40
 
41
  # ๋ฏธ๋ž˜ ๋ฐ์ดํ„ฐ ํ”„๋ ˆ์ž„ ์ƒ์„ฑ ๋ฐ ์˜ˆ์ธก
42
+ future_dates = pd.date_range(start=data['ds'].iloc[-1], periods=periods+1, freq='D')[1:]
43
+ future_lr = pd.DataFrame({'ds': future_dates})
44
+ future_lr['ds'] = future_lr['ds'].dt.strftime('%Y-%m-%d')
45
+ X_future = pd.to_numeric(pd.Series(range(len(data), len(data) + len(future_lr))))
46
+ future_lr['yhat'] = model_lr.predict(X_future.values.reshape(-1, 1))
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # ์˜ˆ์ธก ๊ฒฐ๊ณผ ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
49
  forecast_prophet['ds'] = forecast_prophet['ds'].dt.strftime('%Y-%m-%d')
50
  fig = go.Figure()
51
  fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines', name='Prophet Forecast (Blue)'))
52
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
 
 
 
 
 
53
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
54
 
55
+ return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']]
56
 
 
57
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ๋ฐ ์‹คํ–‰
58
  with gr.Blocks() as app:
59
  with gr.Row():
 
64
  forecast_chart = gr.Plot(label="Forecast Chart")
65
  forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
66
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
 
 
 
 
 
67
 
68
  forecast_button.click(
69
  fn=predict_future_prices,
70
  inputs=[ticker_input, periods_input],
71
+ outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr]
72
  )
73
 
74
  app.launch()