# train.py
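"""Fine-tune the MOMENT-based forecasting model loaded via load_moment_model().

Covers device/AMP setup, OneCycleLR scheduling, per-epoch test evaluation
(MSE/MAE), early stopping on test MAE, checkpointing, and metric export.
"""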
import os
import json
import time
import logging
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error
from transformer_model.scripts.config_transformer import (
    BASE_DIR,
    MAX_EPOCHS,
    BATCH_SIZE,
    LEARNING_RATE,
    MAX_LR,
    GRAD_CLIP,
    FORECAST_HORIZON,
    CHECKPOINT_DIR,
    RESULTS_DIR,
)
from transformer_model.scripts.training.load_basis_model import load_moment_model
from transformer_model.scripts.utils.create_dataloaders import create_dataloaders
from transformer_model.scripts.utils.check_device import check_device
from momentfm.utils.utils import control_randomness
# === Setup logging ===
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def train():
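    """Fine-tune the model, evaluate each epoch, and save checkpoints/metrics.

    Uses AMP when a GradScaler is available, clips gradients to GRAD_CLIP,
    steps a OneCycleLR schedule once per batch, and stops early after
    `patience` epochs without an improvement in test MAE.
    """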
    # Start timing
    start_time = time.time()

    # Setup device (CUDA / DirectML / CPU) and AMP scaler
    device, backend, scaler = check_device()

    # Load base model
    model = load_moment_model().to(device)

    # Set random seeds for reproducibility
    control_randomness(seed=13)

    # Setup loss function and optimizer
    criterion = torch.nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Load data
    train_loader, test_loader = create_dataloaders()

    # Setup learning rate scheduler (OneCycle policy).
    # The schedule spans every batch (total_steps), so it is stepped once per
    # batch in the training loop; pct_start=0.3 ramps up to MAX_LR over the
    # first 30% of steps and anneals for the remainder.
    total_steps = len(train_loader) * MAX_EPOCHS
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=MAX_LR,
        total_steps=total_steps,
        pct_start=0.3
    )

    # Ensure output folders exist
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Metric history and early-stopping state
    train_losses, test_mses, test_maes = [], [], []
    best_mae = float("inf")
    best_epoch = None
    no_improve_epochs = 0
    patience = 5

    for epoch in range(MAX_EPOCHS):
        model.train()
        epoch_losses = []

        for timeseries, forecast, input_mask in tqdm(train_loader, desc=f"Epoch {epoch}"):
            timeseries = timeseries.float().to(device)
            input_mask = input_mask.to(device)
            forecast = forecast.float().to(device)

            # Zero gradients
            optimizer.zero_grad(set_to_none=True)

            # Forward pass (with AMP if enabled)
            if scaler:
                with torch.amp.autocast(device_type="cuda"):
                    output = model(x_enc=timeseries, input_mask=input_mask)
                    loss = criterion(output.forecast, forecast)
            else:
                output = model(x_enc=timeseries, input_mask=input_mask)
                loss = criterion(output.forecast, forecast)

            # Backward pass + optimization
            if scaler:
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is computed on real gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

            # OneCycleLR is defined over all batches, so step once per batch
            scheduler.step()

            epoch_losses.append(loss.item())

        average_train_loss = np.mean(epoch_losses)
        train_losses.append(average_train_loss)
        logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")

        # === Evaluation ===
        model.eval()
        trues, preds = [], []

        with torch.no_grad():
            for timeseries, forecast, input_mask in test_loader:
                timeseries = timeseries.float().to(device)
                input_mask = input_mask.to(device)
                forecast = forecast.float().to(device)

                if scaler:
                    with torch.amp.autocast(device_type="cuda"):
                        output = model(x_enc=timeseries, input_mask=input_mask)
                else:
                    output = model(x_enc=timeseries, input_mask=input_mask)

                trues.append(forecast.detach().cpu().numpy())
                preds.append(output.forecast.detach().cpu().numpy())

        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)

        # Reshape for sklearn metrics
        trues_2d = trues.reshape(trues.shape[0], -1)
        preds_2d = preds.reshape(preds.shape[0], -1)

        mse = mean_squared_error(trues_2d, preds_2d)
        mae = mean_absolute_error(trues_2d, preds_2d)
        test_mses.append(mse)
        test_maes.append(mae)
        logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")

        # === Early Stopping Check ===
        if mae < best_mae:
            best_mae = mae
            best_epoch = epoch
            no_improve_epochs = 0

            # Save best model
            best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_model_path)
            logging.info(f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})")
        else:
            no_improve_epochs += 1
            logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")
            if no_improve_epochs >= patience:
                logging.info("Early stopping triggered.")
                break

        # Save per-epoch checkpoint
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
        torch.save(model.state_dict(), checkpoint_path)

    logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")

    # Save final model
    final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
    torch.save(model.state_dict(), final_model_path)
    logging.info(f"Final model saved to: {final_model_path}")
    logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")

    # Save training metrics
    metrics = {
        "train_losses": [float(x) for x in train_losses],
        "test_mses": [float(x) for x in test_mses],
        "test_maes": [float(x) for x in test_maes],
    }
    metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f)
    logging.info(f"Training metrics saved to: {metrics_path}")

    # Done
    elapsed = time.time() - start_time
    logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")


# === Entry Point ===
if __name__ == "__main__":
    try:
        train()
    except Exception:
        # Log the full traceback, not just the exception message
        logging.exception("Training failed")