dlaj committed on
Commit
8b0f996
·
1 Parent(s): 20b8d14

Deploy from GitHub

Browse files
streamlit_simulation/app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import time
2
  import warnings
3
 
@@ -8,6 +9,7 @@ import pandas as pd
8
  import streamlit as st
9
  import torch
10
  from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
 
11
 
12
  from lightgbm_model.scripts.config_lightgbm import FEATURES
13
  from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
@@ -84,6 +86,21 @@ init_session_state()
84
 
85
 
86
  # ============================== Loaders Cache ==============================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  @st.cache_data
88
  def load_cached_lightgbm_model():
89
  return load_lightgbm_model()
@@ -449,7 +466,7 @@ if model_choice == "Transformer Model (moments)":
449
  len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
450
  + SEQ_LEN
451
  )
452
- base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[
453
  test_start_idx
454
  ] # get original timestamp for later, cause not in dataset anymore
455
 
 
1
+ import os
2
  import time
3
  import warnings
4
 
 
9
  import streamlit as st
10
  import torch
11
  from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
12
+ from huggingface_hub import hf_hub_download
13
 
14
  from lightgbm_model.scripts.config_lightgbm import FEATURES
15
  from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
 
86
 
87
 
88
  # ============================== Loaders Cache ==============================
89
+ HF_REPO = "dlaj/energy-forecasting-files"
90
+ HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
91
+
92
+ # if local data exist, use them; if not, download from the Hugging Face Hub
93
+ if os.path.exists(DATA_PATH):
94
+ CSV_PATH = DATA_PATH
95
+ else:
96
+ CSV_PATH = hf_hub_download(
97
+ repo_id=HF_REPO,
98
+ filename=HF_FILENAME,
99
+ repo_type="dataset",
100
+ cache_dir="hf_cache", # Optional
101
+ )
102
+
103
+
104
  @st.cache_data
105
  def load_cached_lightgbm_model():
106
  return load_lightgbm_model()
 
466
  len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
467
  + SEQ_LEN
468
  )
469
+ base_timestamp = pd.read_csv(CSV_PATH, parse_dates=["date"])["date"].iloc[
470
  test_start_idx
471
  ] # get original timestamp for later, cause not in dataset anymore
472
 
streamlit_simulation/utils_streamlit.py CHANGED
@@ -27,3 +27,17 @@ def load_data():
27
 
28
  print(f"Lade lokale Datei: {DATA_PATH}")
29
  return pd.read_csv(DATA_PATH, parse_dates=["date"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  print(f"Lade lokale Datei: {DATA_PATH}")
29
  return pd.read_csv(DATA_PATH, parse_dates=["date"])
30
+
31
+
32
+ def resolve_csv_path() -> str:
33
+ if os.path.exists(DATA_PATH):
34
+ print(f"Lokale Datei verwendet: {DATA_PATH}")
35
+ return DATA_PATH
36
+ else:
37
+ print(f"Lokale Datei nicht gefunden, lade von HF: {HF_FILENAME}")
38
+ return hf_hub_download(
39
+ repo_id=HF_REPO,
40
+ filename=HF_FILENAME,
41
+ repo_type="dataset",
42
+ cache_dir="hf_cache",
43
+ )
transformer_model/scripts/utils/informer_dataset_class.py CHANGED
@@ -12,6 +12,21 @@ from transformer_model.scripts.config_transformer import DATA_PATH, SEQ_LEN
12
 
13
  logging.basicConfig(level=logging.INFO)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  class InformerDataset:
17
  def __init__(
@@ -39,28 +54,12 @@ class InformerDataset:
39
 
40
  self.seq_len = SEQ_LEN
41
  self.forecast_horizon = forecast_horizon
 
42
  self.data_split = data_split
43
  self.data_stride_len = data_stride_len
44
  self.task_name = task_name
45
  self.random_seed = random_seed
46
 
47
- # use local dataset if available, else download it from huggingface
48
- HF_REPO = "dlaj/energy-forecasting-files"
49
- HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
50
-
51
- if not os.path.exists(DATA_PATH):
52
- print(f"Lokale Datei nicht gefunden: {DATA_PATH}")
53
- print("Lade von Hugging Face Hub...")
54
-
55
- self.full_file_path_and_name = hf_hub_download(
56
- repo_id=HF_REPO,
57
- filename=HF_FILENAME,
58
- repo_type="dataset",
59
- cache_dir="hf_cache", # optional
60
- )
61
- else:
62
- self.full_file_path_and_name = DATA_PATH
63
-
64
  self._read_data()
65
 
66
  def _get_borders(self):
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
 
15
+ HF_REPO = "dlaj/energy-forecasting-files"
16
+ HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
17
+
18
+ if os.path.exists(DATA_PATH):
19
+ print(f"Lokale Datei gefunden: {DATA_PATH}")
20
+ CSV_PATH = DATA_PATH
21
+ else:
22
+ print("Lokale Datei NICHT gefunden! Lade von Hugging Face...")
23
+ CSV_PATH = hf_hub_download(
24
+ repo_id=HF_REPO,
25
+ filename=HF_FILENAME,
26
+ repo_type="dataset",
27
+ cache_dir="hf_cache", # Optional
28
+ )
29
+
30
 
31
  class InformerDataset:
32
  def __init__(
 
54
 
55
  self.seq_len = SEQ_LEN
56
  self.forecast_horizon = forecast_horizon
57
+ self.full_file_path_and_name = CSV_PATH
58
  self.data_split = data_split
59
  self.data_stride_len = data_stride_len
60
  self.task_name = task_name
61
  self.random_seed = random_seed
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  self._read_data()
64
 
65
  def _get_borders(self):