Spaces:
Sleeping
Sleeping
Deploy from GitHub
Browse files
streamlit_simulation/app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import time
|
2 |
import warnings
|
3 |
|
@@ -8,6 +9,7 @@ import pandas as pd
|
|
8 |
import streamlit as st
|
9 |
import torch
|
10 |
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
|
|
|
11 |
|
12 |
from lightgbm_model.scripts.config_lightgbm import FEATURES
|
13 |
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
|
@@ -84,6 +86,21 @@ init_session_state()
|
|
84 |
|
85 |
|
86 |
# ============================== Loaders Cache ==============================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
@st.cache_data
|
88 |
def load_cached_lightgbm_model():
|
89 |
return load_lightgbm_model()
|
@@ -449,7 +466,7 @@ if model_choice == "Transformer Model (moments)":
|
|
449 |
len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
|
450 |
+ SEQ_LEN
|
451 |
)
|
452 |
-
base_timestamp = pd.read_csv(
|
453 |
test_start_idx
|
454 |
] # get original timestamp for later, cause not in dataset anymore
|
455 |
|
|
|
1 |
+
import os
|
2 |
import time
|
3 |
import warnings
|
4 |
|
|
|
9 |
import streamlit as st
|
10 |
import torch
|
11 |
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
|
14 |
from lightgbm_model.scripts.config_lightgbm import FEATURES
|
15 |
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
|
|
|
86 |
|
87 |
|
88 |
# ============================== Loaders Cache ==============================
# Resolve where the processed dataset lives: prefer a local copy when one
# exists, otherwise pull the CSV from the Hugging Face Hub dataset repo.
HF_REPO = "dlaj/energy-forecasting-files"
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"

if not os.path.exists(DATA_PATH):
    # No local file — download (cached by huggingface_hub under hf_cache/).
    CSV_PATH = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILENAME,
        repo_type="dataset",
        cache_dir="hf_cache",  # Optional
    )
else:
    CSV_PATH = DATA_PATH
|
102 |
+
|
103 |
+
|
104 |
@st.cache_resource
def load_cached_lightgbm_model():
    """Load the LightGBM model once per Streamlit process and reuse it.

    Uses ``st.cache_resource`` rather than ``st.cache_data``: a fitted model
    is a shared, non-serializable resource, and ``cache_data`` would try to
    copy/pickle the return value on every access.

    Returns:
        The model object produced by ``load_lightgbm_model()``.
    """
    return load_lightgbm_model()
|
|
|
466 |
len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
|
467 |
+ SEQ_LEN
|
468 |
)
|
469 |
+
base_timestamp = pd.read_csv(CSV_PATH, parse_dates=["date"])["date"].iloc[
|
470 |
test_start_idx
|
471 |
] # get original timestamp for later, cause not in dataset anymore
|
472 |
|
streamlit_simulation/utils_streamlit.py
CHANGED
@@ -27,3 +27,17 @@ def load_data():
|
|
27 |
|
28 |
print(f"Lade lokale Datei: {DATA_PATH}")
|
29 |
return pd.read_csv(DATA_PATH, parse_dates=["date"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
print(f"Lade lokale Datei: {DATA_PATH}")
|
29 |
return pd.read_csv(DATA_PATH, parse_dates=["date"])
|
30 |
+
|
31 |
+
|
32 |
+
def resolve_csv_path() -> str:
    """Return a usable path to the processed CSV.

    Prefers the local file at ``DATA_PATH``; when it is missing, downloads
    the dataset file from the Hugging Face Hub and returns the cached path.
    """
    if not os.path.exists(DATA_PATH):
        print(f"Lokale Datei nicht gefunden, lade von HF: {HF_FILENAME}")
        return hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILENAME,
            repo_type="dataset",
            cache_dir="hf_cache",
        )
    print(f"Lokale Datei verwendet: {DATA_PATH}")
    return DATA_PATH
|
transformer_model/scripts/utils/informer_dataset_class.py
CHANGED
@@ -12,6 +12,21 @@ from transformer_model.scripts.config_transformer import DATA_PATH, SEQ_LEN
|
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
class InformerDataset:
|
17 |
def __init__(
|
@@ -39,28 +54,12 @@ class InformerDataset:
|
|
39 |
|
40 |
self.seq_len = SEQ_LEN
|
41 |
self.forecast_horizon = forecast_horizon
|
|
|
42 |
self.data_split = data_split
|
43 |
self.data_stride_len = data_stride_len
|
44 |
self.task_name = task_name
|
45 |
self.random_seed = random_seed
|
46 |
|
47 |
-
# use local dataset if available, else download it from huggingface
|
48 |
-
HF_REPO = "dlaj/energy-forecasting-files"
|
49 |
-
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
|
50 |
-
|
51 |
-
if not os.path.exists(DATA_PATH):
|
52 |
-
print(f"Lokale Datei nicht gefunden: {DATA_PATH}")
|
53 |
-
print("Lade von Hugging Face Hub...")
|
54 |
-
|
55 |
-
self.full_file_path_and_name = hf_hub_download(
|
56 |
-
repo_id=HF_REPO,
|
57 |
-
filename=HF_FILENAME,
|
58 |
-
repo_type="dataset",
|
59 |
-
cache_dir="hf_cache", # optional
|
60 |
-
)
|
61 |
-
else:
|
62 |
-
self.full_file_path_and_name = DATA_PATH
|
63 |
-
|
64 |
self._read_data()
|
65 |
|
66 |
def _get_borders(self):
|
|
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face fallback for the processed dataset when no local copy exists.
HF_REPO = "dlaj/energy-forecasting-files"
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"

if os.path.exists(DATA_PATH):
    # Prefer the local file to avoid a network round-trip.
    logger.info("Lokale Datei gefunden: %s", DATA_PATH)
    CSV_PATH = DATA_PATH
else:
    # Consistent with the module's logging setup above (was a bare print).
    logger.info("Lokale Datei NICHT gefunden! Lade von Hugging Face...")
    CSV_PATH = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILENAME,
        repo_type="dataset",
        cache_dir="hf_cache",  # Optional
    )
|
29 |
+
|
30 |
|
31 |
class InformerDataset:
|
32 |
def __init__(
|
|
|
54 |
|
55 |
self.seq_len = SEQ_LEN
|
56 |
self.forecast_horizon = forecast_horizon
|
57 |
+
self.full_file_path_and_name = CSV_PATH
|
58 |
self.data_split = data_split
|
59 |
self.data_stride_len = data_stride_len
|
60 |
self.task_name = task_name
|
61 |
self.random_seed = random_seed
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
self._read_data()
|
64 |
|
65 |
def _get_borders(self):
|