# load_basis_model.py # Load and initialize the base MOMENT model before finetuning import torch import logging from momentfm import MOMENTPipeline from transformer_model.scripts.config_transformer import ( FORECAST_HORIZON, FREEZE_ENCODER, FREEZE_EMBEDDER, FREEZE_HEAD, WEIGHT_DECAY, HEAD_DROPOUT, SEQ_LEN ) # Setup logging logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") def load_moment_model(): """ Loads and configures the MOMENT model for forecasting. """ logging.info("Loading MOMENT model...") model = MOMENTPipeline.from_pretrained( "AutonLab/MOMENT-1-large", model_kwargs={ 'task_name': 'forecasting', 'forecast_horizon': FORECAST_HORIZON, # default = 1 'head_dropout': HEAD_DROPOUT, # default = 0.1 'weight_decay': WEIGHT_DECAY, # default = 0.0 'freeze_encoder': FREEZE_ENCODER, # default = True 'freeze_embedder': FREEZE_EMBEDDER, # default = True 'freeze_head': FREEZE_HEAD # default = False } ) model.init() logging.info("Model initialized successfully.") return model def print_trainable_params(model): """ Logs all trainable (unfrozen) parameters of the model. """ logging.info("Unfrozen parameters:") for name, param in model.named_parameters(): if param.requires_grad: logging.info(f" {name}") def test_dummy_forward(model): """ Performs a dummy forward pass to verify the model runs without error. """ logging.info("Running dummy forward pass with random tensors to see if model is running.") dummy_x = torch.randn(16, 1, SEQ_LEN) output = model(x_enc=dummy_x) logging.info("Dummy forward pass successful.") if __name__ == "__main__": model = load_moment_model() print_trainable_params(model) test_dummy_forward(model)