Spaces:
Sleeping
Sleeping
Deploy from GitHub
Browse files
streamlit_simulation/app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import time
|
| 2 |
import warnings
|
| 3 |
|
|
@@ -8,6 +9,7 @@ import pandas as pd
|
|
| 8 |
import streamlit as st
|
| 9 |
import torch
|
| 10 |
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
|
|
|
|
| 11 |
|
| 12 |
from lightgbm_model.scripts.config_lightgbm import FEATURES
|
| 13 |
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
|
|
@@ -84,6 +86,21 @@ init_session_state()
|
|
| 84 |
|
| 85 |
|
| 86 |
# ============================== Loaders Cache ==============================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
@st.cache_data
|
| 88 |
def load_cached_lightgbm_model():
|
| 89 |
return load_lightgbm_model()
|
|
@@ -449,7 +466,7 @@ if model_choice == "Transformer Model (moments)":
|
|
| 449 |
len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
|
| 450 |
+ SEQ_LEN
|
| 451 |
)
|
| 452 |
-
base_timestamp = pd.read_csv(
|
| 453 |
test_start_idx
|
| 454 |
] # get original timestamp for later, cause not in dataset anymore
|
| 455 |
|
|
|
|
| 1 |
+
import os
|
| 2 |
import time
|
| 3 |
import warnings
|
| 4 |
|
|
|
|
| 9 |
import streamlit as st
|
| 10 |
import torch
|
| 11 |
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
|
| 14 |
from lightgbm_model.scripts.config_lightgbm import FEATURES
|
| 15 |
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
# ============================== Loaders Cache ==============================
|
| 89 |
+
HF_REPO = "dlaj/energy-forecasting-files"
|
| 90 |
+
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
|
| 91 |
+
|
| 92 |
+
# if local data, use them, if not, download from huggingface
|
| 93 |
+
if os.path.exists(DATA_PATH):
|
| 94 |
+
CSV_PATH = DATA_PATH
|
| 95 |
+
else:
|
| 96 |
+
CSV_PATH = hf_hub_download(
|
| 97 |
+
repo_id=HF_REPO,
|
| 98 |
+
filename=HF_FILENAME,
|
| 99 |
+
repo_type="dataset",
|
| 100 |
+
cache_dir="hf_cache", # Optional
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
@st.cache_data
|
| 105 |
def load_cached_lightgbm_model():
|
| 106 |
return load_lightgbm_model()
|
|
|
|
| 466 |
len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
|
| 467 |
+ SEQ_LEN
|
| 468 |
)
|
| 469 |
+
base_timestamp = pd.read_csv(CSV_PATH, parse_dates=["date"])["date"].iloc[
|
| 470 |
test_start_idx
|
| 471 |
] # get original timestamp for later, cause not in dataset anymore
|
| 472 |
|
streamlit_simulation/utils_streamlit.py
CHANGED
|
@@ -27,3 +27,17 @@ def load_data():
|
|
| 27 |
|
| 28 |
print(f"Lade lokale Datei: {DATA_PATH}")
|
| 29 |
return pd.read_csv(DATA_PATH, parse_dates=["date"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
print(f"Lade lokale Datei: {DATA_PATH}")
|
| 29 |
return pd.read_csv(DATA_PATH, parse_dates=["date"])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def resolve_csv_path() -> str:
|
| 33 |
+
if os.path.exists(DATA_PATH):
|
| 34 |
+
print(f"Lokale Datei verwendet: {DATA_PATH}")
|
| 35 |
+
return DATA_PATH
|
| 36 |
+
else:
|
| 37 |
+
print(f"Lokale Datei nicht gefunden, lade von HF: {HF_FILENAME}")
|
| 38 |
+
return hf_hub_download(
|
| 39 |
+
repo_id=HF_REPO,
|
| 40 |
+
filename=HF_FILENAME,
|
| 41 |
+
repo_type="dataset",
|
| 42 |
+
cache_dir="hf_cache",
|
| 43 |
+
)
|
transformer_model/scripts/utils/informer_dataset_class.py
CHANGED
|
@@ -12,6 +12,21 @@ from transformer_model.scripts.config_transformer import DATA_PATH, SEQ_LEN
|
|
| 12 |
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class InformerDataset:
|
| 17 |
def __init__(
|
|
@@ -39,28 +54,12 @@ class InformerDataset:
|
|
| 39 |
|
| 40 |
self.seq_len = SEQ_LEN
|
| 41 |
self.forecast_horizon = forecast_horizon
|
|
|
|
| 42 |
self.data_split = data_split
|
| 43 |
self.data_stride_len = data_stride_len
|
| 44 |
self.task_name = task_name
|
| 45 |
self.random_seed = random_seed
|
| 46 |
|
| 47 |
-
# use local dataset if available, else download it from huggingface
|
| 48 |
-
HF_REPO = "dlaj/energy-forecasting-files"
|
| 49 |
-
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
|
| 50 |
-
|
| 51 |
-
if not os.path.exists(DATA_PATH):
|
| 52 |
-
print(f"Lokale Datei nicht gefunden: {DATA_PATH}")
|
| 53 |
-
print("Lade von Hugging Face Hub...")
|
| 54 |
-
|
| 55 |
-
self.full_file_path_and_name = hf_hub_download(
|
| 56 |
-
repo_id=HF_REPO,
|
| 57 |
-
filename=HF_FILENAME,
|
| 58 |
-
repo_type="dataset",
|
| 59 |
-
cache_dir="hf_cache", # optional
|
| 60 |
-
)
|
| 61 |
-
else:
|
| 62 |
-
self.full_file_path_and_name = DATA_PATH
|
| 63 |
-
|
| 64 |
self._read_data()
|
| 65 |
|
| 66 |
def _get_borders(self):
|
|
|
|
| 12 |
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
|
| 15 |
+
HF_REPO = "dlaj/energy-forecasting-files"
|
| 16 |
+
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
|
| 17 |
+
|
| 18 |
+
if os.path.exists(DATA_PATH):
|
| 19 |
+
print(f"Lokale Datei gefunden: {DATA_PATH}")
|
| 20 |
+
CSV_PATH = DATA_PATH
|
| 21 |
+
else:
|
| 22 |
+
print("Lokale Datei NICHT gefunden! Lade von Hugging Face...")
|
| 23 |
+
CSV_PATH = hf_hub_download(
|
| 24 |
+
repo_id=HF_REPO,
|
| 25 |
+
filename=HF_FILENAME,
|
| 26 |
+
repo_type="dataset",
|
| 27 |
+
cache_dir="hf_cache", # Optional
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
|
| 31 |
class InformerDataset:
|
| 32 |
def __init__(
|
|
|
|
| 54 |
|
| 55 |
self.seq_len = SEQ_LEN
|
| 56 |
self.forecast_horizon = forecast_horizon
|
| 57 |
+
self.full_file_path_and_name = CSV_PATH
|
| 58 |
self.data_split = data_split
|
| 59 |
self.data_stride_len = data_stride_len
|
| 60 |
self.task_name = task_name
|
| 61 |
self.random_seed = random_seed
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
self._read_data()
|
| 64 |
|
| 65 |
def _get_borders(self):
|