# train.py
import os
import json
import time
import logging

import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error

from transformer_model.scripts.config_transformer import (
    BASE_DIR, MAX_EPOCHS, BATCH_SIZE, LEARNING_RATE, MAX_LR,
    GRAD_CLIP, FORECAST_HORIZON, CHECKPOINT_DIR, RESULTS_DIR
)
from transformer_model.scripts.training.load_basis_model import load_moment_model
from transformer_model.scripts.utils.create_dataloaders import create_dataloaders
from transformer_model.scripts.utils.check_device import check_device
from momentfm.utils.utils import control_randomness

# === Setup logging ===
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(message)s")


def train():
    # Start timing
    start_time = time.time()

    # Setup device (CUDA / DirectML / CPU) and AMP scaler
    device, backend, scaler = check_device()

    # Set random seeds before building the model and dataloaders, so that
    # initialization and shuffling are reproducible
    control_randomness(seed=13)

    # Load base model
    model = load_moment_model().to(device)

    # Setup loss function and optimizer
    criterion = torch.nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Load data
    train_loader, test_loader = create_dataloaders()

    # Setup learning rate scheduler (OneCycle policy), stepped once per batch
    total_steps = len(train_loader) * MAX_EPOCHS
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=MAX_LR, total_steps=total_steps, pct_start=0.3
    )

    # Ensure output folders exist
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Store metrics
    train_losses, test_mses, test_maes = [], [], []
    best_mae = float("inf")
    best_epoch = None
    no_improve_epochs = 0
    patience = 5

    for epoch in range(MAX_EPOCHS):
        model.train()
        epoch_losses = []

        for timeseries, forecast, input_mask in tqdm(train_loader, desc=f"Epoch {epoch}"):
            timeseries = timeseries.float().to(device)
            input_mask = input_mask.to(device)
            forecast = forecast.float().to(device)

            # Zero gradients
            optimizer.zero_grad(set_to_none=True)

            # Forward pass (with AMP if enabled)
            if scaler:
                with torch.amp.autocast(device_type="cuda"):
                    output = model(x_enc=timeseries, input_mask=input_mask)
                    loss = criterion(output.forecast, forecast)
            else:
                output = model(x_enc=timeseries, input_mask=input_mask)
                loss = criterion(output.forecast, forecast)

            # Backward pass + optimization
            if scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

            # OneCycleLR must be stepped per batch: total_steps above counts
            # batches, not epochs
            scheduler.step()

            epoch_losses.append(loss.item())

        average_train_loss = np.mean(epoch_losses)
        train_losses.append(average_train_loss)
        logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")

        # === Evaluation ===
        model.eval()
        trues, preds = [], []

        with torch.no_grad():
            for timeseries, forecast, input_mask in test_loader:
                timeseries = timeseries.float().to(device)
                input_mask = input_mask.to(device)
                forecast = forecast.float().to(device)

                if scaler:
                    with torch.amp.autocast(device_type="cuda"):
                        output = model(x_enc=timeseries, input_mask=input_mask)
                else:
                    output = model(x_enc=timeseries, input_mask=input_mask)

                trues.append(forecast.detach().cpu().numpy())
                preds.append(output.forecast.detach().cpu().numpy())

        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)

        # Reshape for sklearn metrics
        trues_2d = trues.reshape(trues.shape[0], -1)
        preds_2d = preds.reshape(preds.shape[0], -1)
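        # sklearn's regression metrics expect 2-D (n_samples, n_outputs) arrays;
        # the reshape above flattens the (channel, horizon) forecast dimensions so
        # every predicted value counts equally in the averaged score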
        mse = mean_squared_error(trues_2d, preds_2d)
        mae = mean_absolute_error(trues_2d, preds_2d)
        test_mses.append(mse)
        test_maes.append(mae)

        logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")

        # === Early Stopping Check ===
        if mae < best_mae:
            best_mae = mae
            best_epoch = epoch
            no_improve_epochs = 0

            # Save best model
            best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_model_path)
            logging.info(f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})")
        else:
            no_improve_epochs += 1
            logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")
            if no_improve_epochs >= patience:
                logging.info("Early stopping triggered.")
                break

        # Save per-epoch checkpoint
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
        torch.save(model.state_dict(), checkpoint_path)

    logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")

    # Save final model
    final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
    torch.save(model.state_dict(), final_model_path)
    logging.info(f"Final model saved to: {final_model_path}")
    logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")

    # Save training metrics
    metrics = {
        "train_losses": [float(x) for x in train_losses],
        "test_mses": [float(x) for x in test_mses],
        "test_maes": [float(x) for x in test_maes]
    }
    metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f)
    logging.info(f"Training metrics saved to: {metrics_path}")

    # Done
    elapsed = time.time() - start_time
    logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")


# === Entry Point ===
if __name__ == "__main__":
    try:
        train()
    except Exception as e:
        # logging.exception also records the traceback, unlike logging.error
        logging.exception(f"Training failed: {e}")
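
# Usage note: run `python train.py` from the project root. A saved checkpoint can
# later be restored for inference with, e.g. (a minimal sketch using the imports
# above; adjust the device mapping to your setup):
#   model = load_moment_model()
#   state = torch.load(os.path.join(CHECKPOINT_DIR, "best_model.pth"), map_location="cpu")
#   model.load_state_dict(state)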