aiqtech commited on
Commit
0880176
·
verified ·
1 Parent(s): e266db4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -91,15 +91,15 @@ def predict_future_prices(ticker, periods=1825):
91
  model_xgb.fit(X.reshape(-1, 1), y)
92
  future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future)})
93
 
94
- # SVR 모델 생성 및 학습
95
- model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
96
- model_svr.fit(X.reshape(-1, 1), y)
97
- future_svr = pd.DataFrame({'ds': future_dates, 'yhat': model_svr.predict(X_future)})
98
 
99
- # Bayesian Regression 모델 생성 및 학습
100
- model_bayes = BayesianRidge()
101
- model_bayes.fit(X.reshape(-1, 1), y)
102
- future_bayes = pd.DataFrame({'ds': future_dates, 'yhat': model_bayes.predict(X_future)})
103
 
104
 
105
  # 예측 결과 그래프 생성
@@ -110,8 +110,8 @@ def predict_future_prices(ticker, periods=1825):
110
  fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
111
  fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
112
  fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
113
- fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
114
- fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
115
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
116
 
117
  return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']], future_arima[['ds', 'yhat']], future_lstm[['ds', 'yhat']], future_xgb[['ds', 'yhat']], future_svr[['ds', 'yhat']], future_bayes[['ds', 'yhat']]
@@ -130,13 +130,13 @@ with gr.Blocks() as app:
130
  forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
131
  forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
132
  forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
133
- forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
134
- forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
135
 
136
  forecast_button.click(
137
  fn=predict_future_prices,
138
  inputs=[ticker_input, periods_input],
139
- outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm, forecast_data_xgb, forecast_data_svr, forecast_data_bayes]
140
  )
141
 
142
  app.launch()
 
91
  model_xgb.fit(X.reshape(-1, 1), y)
92
  future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future)})
93
 
94
+ # # SVR 모델 생성 및 학습
95
+ # model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
96
+ # model_svr.fit(X.reshape(-1, 1), y)
97
+ # future_svr = pd.DataFrame({'ds': future_dates, 'yhat': model_svr.predict(X_future)})
98
 
99
+ # # Bayesian Regression 모델 생성 및 학습
100
+ # model_bayes = BayesianRidge()
101
+ # model_bayes.fit(X.reshape(-1, 1), y)
102
+ # future_bayes = pd.DataFrame({'ds': future_dates, 'yhat': model_bayes.predict(X_future)})
103
 
104
 
105
  # 예측 결과 그래프 생성
 
110
  fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
111
  fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
112
  fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
113
+ # fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
114
+ # fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
115
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
116
 
117
  return fig, forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']], future_lr[['ds', 'yhat']], future_arima[['ds', 'yhat']], future_lstm[['ds', 'yhat']], future_xgb[['ds', 'yhat']], future_svr[['ds', 'yhat']], future_bayes[['ds', 'yhat']]
 
130
  forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
131
  forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
132
  forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
133
+ # forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
134
+ # forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
135
 
136
  forecast_button.click(
137
  fn=predict_future_prices,
138
  inputs=[ticker_input, periods_input],
139
+ outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm, forecast_data_xgb] #forecast_data_svr, forecast_data_bayes]
140
  )
141
 
142
  app.launch()