Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from transformer_model.scripts.config_transformer import CHECKPOINT_DIR | |
| from transformer_model.scripts.training.load_basis_model import \ | |
| load_moment_model | |
| logging.basicConfig(level=logging.INFO) | |
| # Load the model from a local checkpoint if available; otherwise download it from the Hugging Face Hub. | |
def load_real_transformer_model(device=None):  # name changed
    """Load the transformer model with the final checkpoint weights.

    The base model is created with ``load_moment_model()`` and its parameters
    are then replaced by the ``model_final.pth`` state dict. The checkpoint is
    taken from ``CHECKPOINT_DIR`` when the file exists there; otherwise it is
    downloaded from the Hugging Face Hub.

    Args:
        device: Device the model is moved to and the checkpoint is mapped
            onto. When ``None``, ``cuda`` is selected if
            ``torch.cuda.is_available()`` is true, else ``cpu``.

    Returns:
        A ``(model, device)`` tuple: the model on ``device`` in eval mode,
        and the device that was used.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_moment_model()

    filename = "model_final.pth"
    local_path = os.path.join(CHECKPOINT_DIR, filename)

    # isfile rather than exists: a directory with the checkpoint's name is
    # not a loadable checkpoint and should fall through to the download.
    if os.path.isfile(local_path):
        checkpoint_path = local_path
        logging.info("Loading model from local path...")
    else:
        logging.info("Downloading model from Hugging Face Hub...")
        checkpoint_path = hf_hub_download(
            repo_id="dlaj/energy-forecasting-files",  # adjust if necessary
            filename=f"transformer_model/{filename}",
            repo_type="dataset",
        )

    # NOTE(review): torch.load unpickles the file, and the checkpoint may have
    # just been downloaded. Consider passing weights_only=True once the
    # minimum supported torch version is confirmed to accept it.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference mode: disables dropout / batch-norm updates

    logging.info("Model loaded from: %s", checkpoint_path)
    return model, device