azrai99 commited on
Commit
a3dc2e1
·
verified ·
1 Parent(s): 21479cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -251,17 +251,19 @@ def transfer_learning_forecasting():
251
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
252
  forecast_results = {}
253
 
254
- start_time = time.time() # Start timing
255
- if model_choice == "NHITS":
256
- forecast_results['NHITS'] = generate_forecast(nhits_model, df)
257
- elif model_choice == "TimesNet":
258
- forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
259
- elif model_choice == "LSTM":
260
- forecast_results['LSTM'] = generate_forecast(lstm_model, df)
261
- elif model_choice == "TFT":
262
- forecast_results['TFT'] = generate_forecast(tft_model, df)
263
 
264
  if st.sidebar.button("Submit"):
 
 
 
 
 
 
 
 
 
 
265
  for model_name, forecast_df in forecast_results.items():
266
  plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
267
 
 
251
  nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
252
  forecast_results = {}
253
 
254
+
 
 
 
 
 
 
 
 
255
 
256
  if st.sidebar.button("Submit"):
257
+ start_time = time.time() # Start timing
258
+ if model_choice == "NHITS":
259
+ forecast_results['NHITS'] = generate_forecast(nhits_model, df)
260
+ elif model_choice == "TimesNet":
261
+ forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
262
+ elif model_choice == "LSTM":
263
+ forecast_results['LSTM'] = generate_forecast(lstm_model, df)
264
+ elif model_choice == "TFT":
265
+ forecast_results['TFT'] = generate_forecast(tft_model, df)
266
+
267
  for model_name, forecast_df in forecast_results.items():
268
  plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
269