# load_basis_model.py
# Load and initialize the base MOMENT model before finetuning.

import logging

import torch
from momentfm import MOMENTPipeline

from transformer_model.scripts.config_transformer import (
    FORECAST_HORIZON,
    FREEZE_EMBEDDER,
    FREEZE_ENCODER,
    FREEZE_HEAD,
    HEAD_DROPOUT,
    SEQ_LEN,
    WEIGHT_DECAY,
)

# Setup logging. Kept at module level so that scripts importing this module
# get the same INFO-level console output as before.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Pretrained checkpoint loaded when no explicit model name is given.
DEFAULT_MODEL_NAME = "AutonLab/MOMENT-1-large"


def load_moment_model(model_name=DEFAULT_MODEL_NAME):
    """Load and configure the MOMENT model for forecasting.

    Args:
        model_name: Identifier passed to ``MOMENTPipeline.from_pretrained``.
            Defaults to ``"AutonLab/MOMENT-1-large"``.

    Returns:
        The ``MOMENTPipeline`` instance after ``init()`` has been called.
    """
    logging.info("Loading MOMENT model...")
    model = MOMENTPipeline.from_pretrained(
        model_name,
        model_kwargs={
            "task_name": "forecasting",
            "forecast_horizon": FORECAST_HORIZON,  # default = 1
            "head_dropout": HEAD_DROPOUT,  # default = 0.1
            "weight_decay": WEIGHT_DECAY,  # default = 0.0
            "freeze_encoder": FREEZE_ENCODER,  # default = True
            "freeze_embedder": FREEZE_EMBEDDER,  # default = True
            "freeze_head": FREEZE_HEAD,  # default = False
        },
    )
    # momentfm's usage examples call init() right after from_pretrained to
    # finish setting up the task-specific model.
    model.init()
    logging.info("Model initialized successfully.")
    return model


def print_trainable_params(model):
    """Log every trainable (unfrozen) parameter of ``model`` and a count summary.

    Args:
        model: Any ``torch.nn.Module`` (here the MOMENT pipeline).
    """
    logging.info("Unfrozen parameters:")
    trainable = 0
    total = 0
    for name, param in model.named_parameters():
        numel = param.numel()
        total += numel
        if param.requires_grad:
            trainable += numel
            logging.info("  %s", name)
    logging.info("Trainable parameter count: %d / %d", trainable, total)


def test_dummy_forward(model, batch_size=16, n_channels=1):
    """Run a dummy forward pass to verify the model executes without error.

    Args:
        model: The loaded MOMENT pipeline.
        batch_size: Number of random series in the dummy batch.
        n_channels: Number of channels per series.

    Returns:
        The raw model output.
    """
    logging.info(
        "Running dummy forward pass with random tensors to see if model is running."
    )
    dummy_x = torch.randn(batch_size, n_channels, SEQ_LEN)
    # Smoke test only: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        output = model(x_enc=dummy_x)
    # MOMENT returns a TimeseriesOutputs object, not a tensor; the prediction
    # lives in `.forecast`. Fall back to the raw output if that attribute is
    # absent.
    prediction = getattr(output, "forecast", output)
    logging.info(
        "Dummy forward pass successful. Output shape: %s",
        tuple(getattr(prediction, "shape", ())),
    )
    return output


if __name__ == "__main__":
    model = load_moment_model()
    print_trainable_params(model)
    test_dummy_forward(model)