Spaces:
Sleeping
Sleeping
# load_basis_model.py | |
# Load and initialize the base MOMENT model before finetuning | |
import logging | |
import torch | |
from momentfm import MOMENTPipeline | |
from transformer_model.scripts.config_transformer import (FORECAST_HORIZON, | |
FREEZE_EMBEDDER, | |
FREEZE_ENCODER, | |
FREEZE_HEAD, | |
HEAD_DROPOUT, | |
SEQ_LEN, | |
WEIGHT_DECAY) | |
# Setup logging | |
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") | |
def load_moment_model():
    """
    Build the pretrained MOMENT pipeline in forecasting mode and initialize it.

    The checkpoint "AutonLab/MOMENT-1-large" is fetched through
    ``MOMENTPipeline.from_pretrained``; every tunable option is read from the
    constants imported from ``config_transformer``.

    Returns:
        The pipeline object after ``init()`` has been called on it.
    """
    logging.info("Loading MOMENT model...")

    # NOTE(review): the "default" values below are the ones recorded by the
    # original author -- confirm them against the installed momentfm version.
    forecasting_kwargs = {
        "task_name": "forecasting",
        "forecast_horizon": FORECAST_HORIZON,  # default = 1
        "head_dropout": HEAD_DROPOUT,  # default = 0.1
        "weight_decay": WEIGHT_DECAY,  # default = 0.0
        "freeze_encoder": FREEZE_ENCODER,  # default = True
        "freeze_embedder": FREEZE_EMBEDDER,  # default = True
        "freeze_head": FREEZE_HEAD,  # default = False
    }
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs=forecasting_kwargs,
    )

    # init() must run before the pipeline is handed back to the caller.
    model.init()
    logging.info("Model initialized successfully.")
    return model
def print_trainable_params(model):
    """
    Log the name of every parameter that still receives gradients.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs; only parameters whose
            ``requires_grad`` flag is truthy are reported.
    """
    logging.info("Unfrozen parameters:")
    # Lazily filter so names are logged in the same order they are yielded.
    trainable_names = (
        param_name
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name in trainable_names:
        logging.info(f" {param_name}")
def test_dummy_forward(model):
    """
    Performs a dummy forward pass to verify the model runs without error.

    A random batch of shape ``(16, 1, SEQ_LEN)`` is passed to ``model`` through
    the ``x_enc`` keyword and the shape of the prediction is logged.

    Args:
        model: Callable accepting ``x_enc``. It may return either a plain
            tensor or an output container that carries the prediction in a
            ``forecast`` attribute; both are handled.

    Raises:
        Any exception raised by the forward pass itself is propagated
        unchanged -- surfacing such errors is the purpose of this check.
    """
    logging.info(
        "Running dummy forward pass with random tensors to see if model is running."
    )

    # Create the batch on the device of the model's parameters so the check
    # also works after the model has been moved off the default device.
    parameters = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(parameters, None)
    device = first_param.device if first_param is not None else None
    dummy_x = torch.randn(16, 1, SEQ_LEN, device=device)

    # Only the output shape matters here; skip building an autograd graph.
    with torch.no_grad():
        output = model(x_enc=dummy_x)

    # The forecasting output is a container, not a tensor: read the prediction
    # from `.forecast` when present, otherwise treat the output as the tensor.
    forecast = getattr(output, "forecast", None)
    prediction = forecast if forecast is not None else output
    logging.info(f"Dummy forward pass successful.Output shape: {prediction.shape}")
if __name__ == "__main__":
    # Script entry point: load the model once, then run each sanity check
    # against it in order (parameter listing first, forward pass second).
    model = load_moment_model()
    for sanity_check in (print_trainable_params, test_dummy_forward):
        sanity_check(model)