3v324v23 commited on
Commit
17380d7
·
1 Parent(s): 5f2dd65

tryout no plot

Browse files
streamlit_simulation/app_backup_hug.py CHANGED
@@ -132,29 +132,23 @@ def load_lightgbm_model():
132
  def load_transformer_model_and_dataset():
133
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
 
135
- try:
136
- with st.spinner("🔄 Loading transformer model..."):
137
- # Load model
138
- model = load_moment_model()
139
- checkpoint_path = hf_hub_download(
140
- repo_id="dlaj/energy-forecasting-files",
141
- filename="transformer_model/model_final.pth",
142
- repo_type="dataset"
143
- )
144
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
145
- model.to(device)
146
- model.eval()
147
 
148
- # Datasets
149
- train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
150
- test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
151
- test_dataset.scaler = train_dataset.scaler
152
 
153
- return model, test_dataset, device
154
-
155
- except Exception as e:
156
- st.error(f"❌ Fehler beim Laden des Transformer-Modells: {e}")
157
- raise e
158
 
159
  @st.cache_data
160
  def load_data():
 
132
  def load_transformer_model_and_dataset():
133
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
 
135
+ # Load model
136
+ model = load_moment_model()
137
+ checkpoint_path = hf_hub_download(
138
+ repo_id="dlaj/energy-forecasting-files",
139
+ filename="transformer_model/model_final.pth",
140
+ repo_type="dataset"
141
+ )
142
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
143
+ model.to(device)
144
+ model.eval()
 
 
145
 
146
+ # Datasets
147
+ train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
148
+ test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
149
+ test_dataset.scaler = train_dataset.scaler
150
 
151
+ return model, test_dataset, device
 
 
 
 
152
 
153
  @st.cache_data
154
  def load_data():