aiqtech commited on
Commit
a8444e5
·
verified ·
1 Parent(s): 0880176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -5,7 +5,7 @@ from sklearn.linear_model import LinearRegression, BayesianRidge
5
  from sklearn.svm import SVR
6
  from sklearn.preprocessing import MinMaxScaler
7
  from statsmodels.tsa.arima.model import ARIMA
8
- from xgboost import XGBRegressor
9
  import pandas as pd
10
  import numpy as np
11
  from datetime import datetime
@@ -86,10 +86,10 @@ def predict_future_prices(ticker, periods=1825):
86
  pred_lstm = scaler.inverse_transform(np.array(pred_lstm).reshape(-1, 1))
87
  future_lstm = pd.DataFrame({'ds': future_dates[:len(pred_lstm)], 'yhat': pred_lstm.flatten()})
88
 
89
- # XGBoost 모델 생성 및 학습
90
- model_xgb = XGBRegressor(n_estimators=100, learning_rate=0.1)
91
- model_xgb.fit(X.reshape(-1, 1), y)
92
- future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future)})
93
 
94
  # # SVR 모델 생성 및 학습
95
  # model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
@@ -109,7 +109,7 @@ def predict_future_prices(ticker, periods=1825):
109
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
110
  fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
111
  fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
112
- fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
113
  # fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
114
  # fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
115
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
@@ -129,14 +129,14 @@ with gr.Blocks() as app:
129
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
130
  forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
131
  forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
132
- forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
133
  # forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
134
  # forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
135
 
136
  forecast_button.click(
137
  fn=predict_future_prices,
138
  inputs=[ticker_input, periods_input],
139
- outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm, forecast_data_xgb] #forecast_data_svr, forecast_data_bayes]
140
  )
141
 
142
  app.launch()
 
5
  from sklearn.svm import SVR
6
  from sklearn.preprocessing import MinMaxScaler
7
  from statsmodels.tsa.arima.model import ARIMA
8
+ #from xgboost import XGBRegressor
9
  import pandas as pd
10
  import numpy as np
11
  from datetime import datetime
 
86
  pred_lstm = scaler.inverse_transform(np.array(pred_lstm).reshape(-1, 1))
87
  future_lstm = pd.DataFrame({'ds': future_dates[:len(pred_lstm)], 'yhat': pred_lstm.flatten()})
88
 
89
+ # # XGBoost 모델 생성 및 학습
90
+ # model_xgb = XGBRegressor(n_estimators=100, learning_rate=0.1)
91
+ # model_xgb.fit(X.reshape(-1, 1), y)
92
+ # future_xgb = pd.DataFrame({'ds': future_dates, 'yhat': model_xgb.predict(X_future)})
93
 
94
  # # SVR 모델 생성 및 학습
95
  # model_svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
 
109
  fig.add_trace(go.Scatter(x=future_lr['ds'], y=future_lr['yhat'], mode='lines', name='Linear Regression Forecast (Red)', line=dict(color='red')))
110
  fig.add_trace(go.Scatter(x=future_arima['ds'], y=future_arima['yhat'], mode='lines', name='ARIMA Forecast (Green)', line=dict(color='green')))
111
  fig.add_trace(go.Scatter(x=future_lstm['ds'], y=future_lstm['yhat'], mode='lines', name='LSTM Forecast (Orange)', line=dict(color='orange')))
112
+ # fig.add_trace(go.Scatter(x=future_xgb['ds'], y=future_xgb['yhat'], mode='lines', name='XGBoost Forecast (Purple)', line=dict(color='purple')))
113
  # fig.add_trace(go.Scatter(x=future_svr['ds'], y=future_svr['yhat'], mode='lines', name='SVR Forecast (Brown)', line=dict(color='brown')))
114
  # fig.add_trace(go.Scatter(x=future_bayes['ds'], y=future_bayes['yhat'], mode='lines', name='Bayesian Regression Forecast (Pink)', line=dict(color='pink')))
115
  fig.add_trace(go.Scatter(x=data['ds'], y=data['y'], mode='lines', name='Actual (Black)', line=dict(color='black')))
 
129
  forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
130
  forecast_data_arima = gr.Dataframe(label="ARIMA Forecast Data")
131
  forecast_data_lstm = gr.Dataframe(label="LSTM Forecast Data")
132
+ # forecast_data_xgb = gr.Dataframe(label="XGBoost Forecast Data")
133
  # forecast_data_svr = gr.Dataframe(label="SVR Forecast Data")
134
  # forecast_data_bayes = gr.Dataframe(label="Bayesian Regression Forecast Data")
135
 
136
  forecast_button.click(
137
  fn=predict_future_prices,
138
  inputs=[ticker_input, periods_input],
139
+ outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_arima, forecast_data_lstm] #,forecast_data_xgb, forecast_data_svr, forecast_data_bayes]
140
  )
141
 
142
  app.launch()