Commit 0a66c45 · Parent(s): 4a62848
3v324v23 committed

tryout no plot

streamlit_simulation/app.py CHANGED
@@ -142,29 +142,29 @@ def load_lightgbm_model():
 def load_transformer_model_and_dataset():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    try:
-        with st.spinner("🔄 Loading transformer model..."):
-            # Load model
-            model = load_moment_model()
-            checkpoint_path = hf_hub_download(
-                repo_id="dlaj/energy-forecasting-files",
-                filename="transformer_model/model_final.pth",
-                repo_type="dataset"
-            )
-            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
-            model.to(device)
-            model.eval()
+    model = load_moment_model()
+    checkpoint_path = hf_hub_download(
+        repo_id="dlaj/energy-forecasting-files",
+        filename="transformer_model/model_final.pth",
+        repo_type="dataset"
+    )
+    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
+    model.to(device)
+    model.eval()
 
-            # Datasets
-            train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
-            test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
-            test_dataset.scaler = train_dataset.scaler
+    csv_path = hf_hub_download(
+        repo_id="dlaj/energy-forecasting-files",
+        filename="data/processed/energy_consumption_aggregated_cleaned.csv",
+        repo_type="dataset"
+    )
+
+    # Datasets
+    train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13, csv_path=csv_path)
+    test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13, csv_path=csv_path)
+    test_dataset.scaler = train_dataset.scaler
+
+    return model, test_dataset, device
 
-            return model, test_dataset, device
-
-    except Exception as e:
-        st.error(f"❌ Fehler beim Laden des Transformer-Modells: {e}")
-        raise e
 
 @st.cache_data
 def load_data():
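
For context, a minimal sketch of how the slimmed-down loader might now be wired up at its call site: the spinner and error handling that this commit removes from load_transformer_model_and_dataset() would then live with the caller. The st.cache_resource decorator and the wrapper name are assumptions for illustration, not part of this commit.

import streamlit as st  # app.py already imports this; repeated for self-containment

# Hypothetical wrapper (name and caching strategy assumed, not from this commit).
@st.cache_resource
def get_transformer_assets():
    return load_transformer_model_and_dataset()

try:
    with st.spinner("🔄 Loading transformer model..."):
        model, test_dataset, device = get_transformer_assets()
except Exception as e:
    st.error(f"❌ Failed to load transformer model: {e}")
    st.stop()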
transformer_model/scripts/utils/informer_dataset_class.py CHANGED
@@ -18,6 +18,7 @@ class InformerDataset:
         data_stride_len: int = 1,
         task_name: str = "forecasting",
         random_seed: int = 42,
+        csv_path=None
     ):
         """
         Parameters
@@ -36,7 +37,7 @@ class InformerDataset:
 
         self.seq_len = SEQ_LEN
         self.forecast_horizon = forecast_horizon
-        self.full_file_path_and_name = DATA_PATH
+        self.full_file_path_and_name = csv_path if csv_path is not None else DATA_PATH
         self.data_split = data_split
         self.data_stride_len = data_stride_len
         self.task_name = task_name
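
To illustrate the new csv_path parameter, a short usage sketch follows. The import path is inferred from the file location, FORECAST_HORIZON stands in for the app's configured value, and the repo/file names are copied from the app.py hunk above; treat the setup as an assumption rather than project code.

from huggingface_hub import hf_hub_download
from transformer_model.scripts.utils.informer_dataset_class import InformerDataset  # import path assumed

FORECAST_HORIZON = 24  # placeholder; app.py defines the real horizon

# Fetch the cleaned CSV from the same dataset repo that app.py uses.
csv_path = hf_hub_download(
    repo_id="dlaj/energy-forecasting-files",
    filename="data/processed/energy_consumption_aggregated_cleaned.csv",
    repo_type="dataset",
)

# With csv_path set, the dataset reads from the downloaded file ...
train_ds = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON,
                           random_seed=13, csv_path=csv_path)

# ... and with csv_path omitted it defaults to None, falling back to DATA_PATH.
fallback_ds = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)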