3v324v23 commited on
Commit
5f2dd65
·
1 Parent(s): 54d3c06
streamlit_simulation/app.py CHANGED
@@ -142,23 +142,29 @@ def load_lightgbm_model():
142
  def load_transformer_model_and_dataset():
143
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
 
145
- # Load model
146
- model = load_moment_model()
147
- checkpoint_path = hf_hub_download(
148
- repo_id="dlaj/energy-forecasting-files",
149
- filename="transformer_model/model_final.pth",
150
- repo_type="dataset"
151
- )
152
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
153
- model.to(device)
154
- model.eval()
 
 
155
 
156
- # Datasets
157
- train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
158
- test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
159
- test_dataset.scaler = train_dataset.scaler
160
 
161
- return model, test_dataset, device
 
 
 
 
162
 
163
  @st.cache_data
164
  def load_data():
 
142
  def load_transformer_model_and_dataset():
143
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
 
145
+ try:
146
+ with st.spinner("🔄 Loading transformer model..."):
147
+ # Load model
148
+ model = load_moment_model()
149
+ checkpoint_path = hf_hub_download(
150
+ repo_id="dlaj/energy-forecasting-files",
151
+ filename="transformer_model/model_final.pth",
152
+ repo_type="dataset"
153
+ )
154
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
155
+ model.to(device)
156
+ model.eval()
157
 
158
+ # Datasets
159
+ train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
160
+ test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
161
+ test_dataset.scaler = train_dataset.scaler
162
 
163
+ return model, test_dataset, device
164
+
165
+ except Exception as e:
166
+ st.error(f"❌ Fehler beim Laden des Transformer-Modells: {e}")
167
+ raise e
168
 
169
  @st.cache_data
170
  def load_data():
streamlit_simulation/app_backup_hug.py CHANGED
@@ -132,23 +132,29 @@ def load_lightgbm_model():
132
  def load_transformer_model_and_dataset():
133
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
 
135
- # Load model
136
- model = load_moment_model()
137
- checkpoint_path = hf_hub_download(
138
- repo_id="dlaj/energy-forecasting-files",
139
- filename="transformer_model/model_final.pth",
140
- repo_type="dataset"
141
- )
142
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
143
- model.to(device)
144
- model.eval()
 
 
145
 
146
- # Datasets
147
- train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
148
- test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
149
- test_dataset.scaler = train_dataset.scaler
150
 
151
- return model, test_dataset, device
 
 
 
 
152
 
153
  @st.cache_data
154
  def load_data():
 
132
  def load_transformer_model_and_dataset():
133
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
 
135
+ try:
136
+ with st.spinner("🔄 Loading transformer model..."):
137
+ # Load model
138
+ model = load_moment_model()
139
+ checkpoint_path = hf_hub_download(
140
+ repo_id="dlaj/energy-forecasting-files",
141
+ filename="transformer_model/model_final.pth",
142
+ repo_type="dataset"
143
+ )
144
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
145
+ model.to(device)
146
+ model.eval()
147
 
148
+ # Datasets
149
+ train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
150
+ test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
151
+ test_dataset.scaler = train_dataset.scaler
152
 
153
+ return model, test_dataset, device
154
+
155
+ except Exception as e:
156
+ st.error(f"❌ Fehler beim Laden des Transformer-Modells: {e}")
157
+ raise e
158
 
159
  @st.cache_data
160
  def load_data():