# load_basis_model.py
# Load and initialize the base MOMENT model before finetuning
import logging
import torch
from momentfm import MOMENTPipeline
from transformer_model.scripts.config_transformer import (
    FORECAST_HORIZON,
    FREEZE_EMBEDDER,
    FREEZE_ENCODER,
    FREEZE_HEAD,
    HEAD_DROPOUT,
    SEQ_LEN,
    WEIGHT_DECAY,
)
# Setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


def load_moment_model():
    """
    Loads and configures the MOMENT model for forecasting.
    """
    logging.info("Loading MOMENT model...")
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={
            "task_name": "forecasting",
            "forecast_horizon": FORECAST_HORIZON,  # default = 1
            "head_dropout": HEAD_DROPOUT,  # default = 0.1
            "weight_decay": WEIGHT_DECAY,  # default = 0.0
            "freeze_encoder": FREEZE_ENCODER,  # default = True
            "freeze_embedder": FREEZE_EMBEDDER,  # default = True
            "freeze_head": FREEZE_HEAD,  # default = False
        },
    )
    model.init()
    logging.info("Model initialized successfully.")
    return model


def print_trainable_params(model):
    """
    Logs all trainable (unfrozen) parameters of the model.
    """
    logging.info("Unfrozen parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            logging.info(f" {name}")
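

# Optional sketch (not part of the original script): besides listing the names
# of the unfrozen parameters, it can be useful to log how many parameters will
# actually be trained before finetuning starts. Uses only standard PyTorch calls.
def count_trainable_params(model):
    """
    Logs the total number of trainable parameters in the model.
    """
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Total trainable parameters: {total:,}")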


def test_dummy_forward(model):
    """
    Performs a dummy forward pass to verify the model runs without error.
    """
    logging.info(
        "Running dummy forward pass with random tensors to verify the model runs."
    )
    dummy_x = torch.randn(16, 1, SEQ_LEN)  # (batch_size, n_channels, sequence_length)
    # The forecasting pipeline returns an output object rather than a raw tensor;
    # the predictions are in its `forecast` attribute.
    output = model(x_enc=dummy_x)
    logging.info(f"Dummy forward pass successful. Output shape: {output.forecast.shape}")


if __name__ == "__main__":
    model = load_moment_model()
    print_trainable_params(model)
    test_dummy_forward(model)