ReCEP / src /bce /utils /trainer.py
NielTT's picture
Upload 108 files
e611d1f verified
import os
import json
import time
import numpy as np
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from torch.amp import autocast, GradScaler
from ..loss import CLoss
from .metrics import calculate_graph_metrics, calculate_node_metrics
from .results import evaluate_ReCEP
from .constants import BASE_DIR
from ..data.data import create_data_loader
from ..model.ReCEP import get_model, ReCEP
from ..model.scheduler import get_scheduler
torch.set_num_threads(12)
def check_for_nan_inf(tensor, name="tensor"):
"""Check for NaN or Inf values in tensor and raise error if found."""
if torch.isnan(tensor).any():
raise ValueError(f"NaN detected in {name}")
if torch.isinf(tensor).any():
raise ValueError(f"Inf detected in {name}")
def format_loss_info(loss_dict, prefix=""):
"""
Format loss information from loss_dict for display.
Args:
loss_dict: Dictionary containing loss components
prefix: Prefix for the output string
Returns:
Formatted string with loss breakdown
"""
main_info = f"{prefix}Total: {loss_dict['loss/total']:.4f}, " \
f"Region: {loss_dict['loss/region']:.4f}, " \
f"Node: {loss_dict['loss/node']:.4f}"
additional_info = []
if loss_dict.get('loss/cls', 0) > 0:
additional_info.append(f"Cls: {loss_dict['loss/cls']:.4f}, Reg: {loss_dict['loss/reg']:.4f}")
if loss_dict.get('loss/consistency', 0) > 0:
additional_info.append(f"Consistency: {loss_dict['loss/consistency']:.4f}")
if additional_info:
return main_info + "\n" + prefix + " " + ", ".join(additional_info)
return main_info
class Trainer:
"""
Trainer class for ReCEP model with comprehensive training and evaluation.
Features:
- Early stopping with patience
- Model checkpointing (best AUPRC and best F1)
- Mixed precision training
- Comprehensive metrics calculation
- Learning rate scheduling
"""
@staticmethod
def convert_to_serializable(obj):
"""Convert numpy/torch types to Python native types for JSON serialization."""
if isinstance(obj, dict):
return {k: Trainer.convert_to_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [Trainer.convert_to_serializable(item) for item in obj]
elif isinstance(obj, (np.float32, np.float64)):
return float(obj)
elif isinstance(obj, (np.int32, np.int64)):
return int(obj)
elif hasattr(obj, 'item'): # For torch tensors and numpy scalars
return obj.item()
else:
return obj
def __init__(self, args):
"""
Initialize Trainer with datasets, model, and hyperparameters.
Args:
args: Argument namespace containing all training parameters
"""
self.args = args
# Set device
self.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {self.device}")
# Create data loaders
use_embeddings2 = False if args.encoder == "esmc" else True
start_time = time.time()
self.train_loader, self.test_loader = create_data_loader(
radii=args.radii,
batch_size=args.batch_size,
undersample=args.undersample,
zero_ratio=args.zero_ratio,
seed=args.seed,
use_embeddings2=use_embeddings2
)
self.val_loader = self.test_loader
end_time = time.time()
print(f"[INFO] Data loading time: {end_time - start_time:.2f} seconds")
print(f"[INFO] Train samples: {len(self.train_loader.dataset)}")
print(f"[INFO] Test samples: {len(self.test_loader.dataset)}")
# Create directories
timestamp = getattr(args, 'timestamp', None)
if timestamp is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.results_dir = Path(f"{BASE_DIR}/results/ReCEP/{timestamp}")
self.results_dir.mkdir(parents=True, exist_ok=True)
self.model_dir = Path(f"{BASE_DIR}/models/ReCEP/{timestamp}")
self.model_dir.mkdir(parents=True, exist_ok=True)
# Model save paths
self.best_auprc_model_path = Path(self.model_dir / "best_auprc_model.bin")
self.best_mcc_model_path = Path(self.model_dir / "best_mcc_model.bin")
self.last_model_path = Path(self.model_dir / "last_model.bin")
# Initialize model using the flexible model loader
self.model = get_model(args)
# Store finetune configuration for later use
self.is_finetune_mode = getattr(args, 'mode', 'train') == 'finetune'
if self.is_finetune_mode:
print("[INFO] Finetune mode enabled - will load pretrained model and freeze most parameters")
# Initialize optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay
)
# Initialize scheduler
self.scheduler = get_scheduler(
args=args,
optimizer=self.optimizer,
num_samples=len(self.train_loader.dataset)
)
# Initialize loss function
self.criterion = CLoss(
region_loss_type=args.region_loss_type,
region_weight=args.region_weight,
node_loss_type=args.node_loss_type,
node_loss_weight=args.node_loss_weight,
consistency_weight=args.consistency_weight,
consistency_type=args.consistency_type,
threshold=args.threshold,
label_smoothing=args.label_smoothing,
gradnorm=getattr(args, 'gradnorm', False),
gradnorm_alpha=getattr(args, 'gradnorm_alpha', 1.5),
**{k: v for k, v in vars(args).items() if k in ['alpha', 'pos_weight', 'reg_weight', 'gamma_high_cls', 'cls_type', 'regression_type', 'weight_mode']}
)
# GradNorm parameters
self.gradnorm_enabled = getattr(args, 'gradnorm', False)
self.gradnorm_update_freq = getattr(args, 'gradnorm_update_freq', 10)
self.gradnorm_step_counter = 0
if self.gradnorm_enabled:
print(f"[INFO] GradNorm enabled with alpha={getattr(args, 'gradnorm_alpha', 1.5)}, update frequency={self.gradnorm_update_freq}")
# Initialize optimizer for task weights
if hasattr(self.criterion, 'log_w_region') and hasattr(self.criterion, 'log_w_node'):
self.task_weight_optimizer = torch.optim.Adam([
self.criterion.log_w_region,
self.criterion.log_w_node
], lr=0.025) # Higher learning rate for task weights
print(f"[INFO] GradNorm task weight optimizer initialized")
else:
print(f"[WARNING] GradNorm enabled but task weight parameters not found in criterion")
self.gradnorm_enabled = False
# Mixed precision setup
self.mixed_precision = args.mixed_precision
if self.mixed_precision:
self.scaler = GradScaler('cuda')
print("[INFO] Using mixed precision training")
else:
self.scaler = None
print("[INFO] Using full precision training")
# Training parameters
self.patience = args.patience
self.threshold = args.threshold
self.log_steps = 50
self.max_grad_norm = getattr(args, 'max_grad_norm', 1.0) # Gradient clipping threshold
# Training history
self.train_history = []
self.val_history = []
# Save configuration
if not self.is_finetune_mode:
self.save_config()
self.best_threshold = None
def train(self):
"""Train the model with early stopping and checkpointing."""
print("[INFO] Starting training...")
self.model.to(self.device)
# Also move loss function to device (important for GradNorm parameters)
self.criterion.to(self.device)
# Print model info
self.model.print_param_count()
# Initialize tracking variables
best_val_auprc = 0.0
best_val_mcc = 0.0
patience_counter = -5 # 5 epochs avoid stopping too early
start_time = time.time()
for epoch in range(self.args.num_epoch):
print(f"\n=== Epoch {epoch + 1}/{self.args.num_epoch} ===")
# Training phase
train_metrics = self._train_epoch(epoch)
self.train_history.append(train_metrics)
# Validation phase
if self.val_loader:
val_metrics = self._validate_epoch(epoch)
self.val_history.append(val_metrics)
# Check for improvement
current_auprc = val_metrics['node_auprc']
current_mcc = val_metrics['graph_mcc']
improved = False
# Save best AUPRC model
if current_auprc > best_val_auprc:
best_val_auprc = current_auprc
print(f"[INFO] New best AUPRC: {best_val_auprc:.4f} - Model saved")
self.model.save(self.best_auprc_model_path, threshold=val_metrics['threshold'])
improved = True
# Save best F1 model
if current_mcc > best_val_mcc:
best_val_mcc = current_mcc
print(f"[INFO] New best MCC: {best_val_mcc:.4f} - Model saved")
self.model.save(self.best_mcc_model_path, threshold=val_metrics['threshold'])
self.best_threshold = val_metrics['threshold']
improved = True
# Early stopping logic
if improved:
patience_counter = 0
else:
patience_counter += 1
print(f"[INFO] No improvement for {patience_counter}/{self.patience} epochs")
if patience_counter >= self.patience:
print(f"[INFO] Early stopping triggered after {epoch + 1} epochs")
break
# Save last model
self.model.save(self.last_model_path)
end_time = time.time()
print(f"\n[INFO] Training completed for {end_time - start_time:.2f} seconds")
print(f"[INFO] Best validation AUPRC: {best_val_auprc:.4f}")
print(f"[INFO] Best validation MCC: {best_val_mcc:.4f}")
# Save training history
self._save_training_history()
# Evaluate the model
print("\n" + "="*80)
print("[INFO] Evaluating best AUPRC model...")
results = evaluate_ReCEP(
model_path=self.best_auprc_model_path,
device_id=self.args.device_id,
radius=self.args.radius,
threshold=self.best_threshold,
k=self.args.k,
verbose=True,
split="test",
encoder=self.args.encoder
)
print("="*80)
print("\n[INFO] Evaluating best MCC model...")
results = evaluate_ReCEP(
model_path=self.best_mcc_model_path,
device_id=self.args.device_id,
radius=self.args.radius,
threshold=self.best_threshold,
k=self.args.k,
verbose=True,
split="test",
encoder=self.args.encoder
)
def finetune(self, pretrained_model_path=None):
"""
Finetune the node prediction parameters of the model.
Args:
pretrained_model_path: Path to pretrained model. If None, uses best_auprc_model_path
"""
print("[INFO] Starting finetuning...")
# Load pretrained model
if pretrained_model_path is None:
# Prefer AUPRC model, fallback to MCC model
if self.best_mcc_model_path.exists():
pretrained_model_path = self.best_mcc_model_path
print(f"[INFO] Using best MCC model for finetuning")
elif self.best_auprc_model_path.exists():
pretrained_model_path = self.best_auprc_model_path
print(f"[INFO] Using best AUPRC model for finetuning")
else:
raise FileNotFoundError(
f"No pretrained model found. Please train the model first or provide a pretrained_model_path.\n"
f"Expected paths: {self.best_auprc_model_path} or {self.best_mcc_model_path}"
)
else:
# Convert string path to Path object if needed
pretrained_model_path = Path(pretrained_model_path)
if not pretrained_model_path.exists():
raise FileNotFoundError(f"Pretrained model not found at {pretrained_model_path}")
print(f"[INFO] Loading pretrained model from {pretrained_model_path}")
self.model, loaded_threshold = ReCEP.load(pretrained_model_path, device=self.device)
print(f"[INFO] Loaded model with threshold: {loaded_threshold}")
# Freeze all parameters first
for param in self.model.parameters():
param.requires_grad = False
# Unfreeze specific modules for node prediction
trainable_modules = ['node_gate', 'node_classifier'] # Correct module names from ReCEP model
trainable_params = []
frozen_params = []
print("[INFO] Setting trainable parameters:")
for name, param in self.model.named_parameters():
is_trainable = False
for module_name in trainable_modules:
if module_name in name:
param.requires_grad = True
is_trainable = True
trainable_params.append(param)
print(f" - {name} (trainable)")
break
if not is_trainable:
frozen_params.append(param)
print(f"[INFO] Trainable parameters: {len(trainable_params)}")
print(f"[INFO] Frozen parameters: {len(frozen_params)}")
# Print model info
self.model.print_param_count()
if len(trainable_params) == 0:
raise ValueError("No trainable parameters found! Check trainable_modules list.")
# Reinitialize optimizer with only trainable parameters
self.optimizer = torch.optim.AdamW(
trainable_params,
lr=self.args.lr * 0.1, # Use smaller learning rate for finetuning
weight_decay=self.args.weight_decay
)
# Reinitialize scheduler
self.args.schedule_type = "cosine"
self.args.warmup_ratio = 0.0
self.scheduler = get_scheduler(
args=self.args,
optimizer=self.optimizer,
num_samples=len(self.train_loader.dataset)
)
# Adjust loss function weights for node-focused training
print("[INFO] Adjusting loss weights for node-focused training")
self.criterion.region_weight = 0.0
self.criterion.node_loss_weight = 1.0
self.gradnorm_enabled = False
# Reset training history for finetuning
self.train_history = []
self.val_history = []
# Update model save paths for finetuning
self.best_finetuned_model_path = self.model_dir / "best_finetuned_model.bin"
self.last_finetuned_model_path = self.model_dir / "last_finetuned_model.bin"
print("[INFO] Finetune setup completed. Use train() method to start finetuning.")
# Start finetuning training loop
self._finetune_train()
print("\n" + "="*80)
print("[INFO] Evaluating finetuned model...")
results = evaluate_ReCEP(
model_path=self.best_finetuned_model_path,
device_id=self.args.device_id,
radius=self.args.radius,
k=self.args.k,
verbose=True,
split="test"
)
def _finetune_train(self):
"""Training loop specifically for finetuning."""
print("[INFO] Starting finetune training loop...")
self.model.to(self.device)
# Also move loss function to device (important for GradNorm parameters)
self.criterion.to(self.device)
# Initialize tracking variables
best_val_auprc = 0.0
patience_counter = 0 # No initial patience for finetuning
start_time = time.time()
# Use fewer epochs for finetuning
finetune_epochs = self.args.num_epoch # Max 60 epochs for finetuning
for epoch in range(finetune_epochs):
print(f"\n=== Finetune Epoch {epoch + 1}/{finetune_epochs} ===")
# Training phase
train_metrics = self._train_epoch(epoch)
self.train_history.append(train_metrics)
# Validation phase
if self.val_loader:
val_metrics = self._validate_epoch(epoch)
self.val_history.append(val_metrics)
# Check for improvement (focus on node metrics)
current_auprc = val_metrics['node_auprc']
improved = False
# Save best AUPRC model
if current_auprc > best_val_auprc:
best_val_auprc = current_auprc
print(f"[INFO] New best AUPRC: {best_val_auprc:.4f} - Finetuned model saved")
self.model.save(self.best_finetuned_model_path, threshold=val_metrics['threshold'])
improved = True
# Early stopping logic (shorter patience for finetuning)
if improved:
patience_counter = 0
else:
patience_counter += 1
print(f"[INFO] No improvement for {patience_counter}/{min(self.patience, 10)} epochs")
if patience_counter >= min(self.patience, 10): # Shorter patience
print(f"[INFO] Early stopping triggered after {epoch + 1} epochs")
break
# Save last model
self.model.save(self.last_finetuned_model_path)
end_time = time.time()
print(f"\n[INFO] Finetuning completed in {end_time - start_time:.2f} seconds")
print(f"[INFO] Best validation AUPRC: {best_val_auprc:.4f}")
# Save training history
self._save_training_history()
def _train_epoch(self, epoch):
"""Train for one epoch."""
self.model.train()
# Initialize accumulators for loss components and predictions
loss_accumulators = defaultdict(float)
all_global_preds = []
all_global_labels = []
all_node_preds = []
all_node_labels = []
pbar = tqdm(self.train_loader, desc=f"Training Epoch {epoch + 1}")
for batch_idx, batch in enumerate(pbar):
batch = batch.to(self.device)
# Prepare targets
targets = {
'y': batch.y,
'y_node': batch.y_node
}
# Forward pass
if self.mixed_precision:
with autocast('cuda'):
outputs = self.model(batch)
if self.gradnorm_enabled:
loss, loss_dict, individual_losses = self.criterion(outputs, targets, batch.batch, return_individual_losses=True)
else:
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
# Check for NaN/Inf in loss
if torch.isnan(loss) or torch.isinf(loss):
# print(f"[WARNING] NaN/Inf loss detected at batch {batch_idx}, skipping batch")
continue
# Backward pass and GradNorm update
self.optimizer.zero_grad()
# GradNorm update BEFORE main backward to preserve computation graph
if self.gradnorm_enabled and self.gradnorm_step_counter % self.gradnorm_update_freq == 0:
# Update GradNorm weights first
self.task_weight_optimizer.zero_grad()
gradnorm_info = self.criterion.update_gradnorm_weights(individual_losses, self.model)
if gradnorm_info:
loss_dict.update(gradnorm_info)
self.task_weight_optimizer.step()
# Main backward pass
self.scaler.scale(loss).backward()
# Gradient clipping
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
outputs = self.model(batch)
if self.gradnorm_enabled:
loss, loss_dict, individual_losses = self.criterion(outputs, targets, batch.batch, return_individual_losses=True)
else:
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
# Check for NaN/Inf in loss
if torch.isnan(loss) or torch.isinf(loss):
# print(f"[WARNING] NaN/Inf loss detected at batch {batch_idx}, skipping batch")
continue
# Backward pass and GradNorm update
self.optimizer.zero_grad()
loss.backward(retain_graph=self.gradnorm_enabled)
# GradNorm update BEFORE main backward to preserve computation graph
if self.gradnorm_enabled and self.gradnorm_step_counter % self.gradnorm_update_freq == 0:
# Update GradNorm weights first
self.task_weight_optimizer.zero_grad()
gradnorm_info = self.criterion.update_gradnorm_weights(individual_losses, self.model)
if gradnorm_info:
loss_dict.update(gradnorm_info)
self.task_weight_optimizer.step()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
if hasattr(self.scheduler, 'step'):
self.scheduler.step()
# Update GradNorm step counter
if self.gradnorm_enabled:
self.gradnorm_step_counter += 1
# Accumulate loss components from loss_dict
for key, value in loss_dict.items():
if key.startswith('loss/'):
loss_accumulators[key] += value
# Extract predictions from loss_dict (already detached)
global_preds = torch.sigmoid(loss_dict['logits/global']).cpu().numpy()
global_labels = batch.y.detach().cpu().numpy()
node_preds = torch.sigmoid(loss_dict['logits/node']).cpu().numpy()
node_labels = batch.y_node.detach().cpu().numpy()
all_global_preds.extend(global_preds)
all_global_labels.extend(global_labels)
all_node_preds.extend(node_preds)
all_node_labels.extend(node_labels)
# Update progress bar with detailed loss info
pbar.set_postfix({
'Total': f"{loss_dict['loss/total']:.4f}",
'Region': f"{loss_dict['loss/region']:.4f}",
'Node': f"{loss_dict['loss/node']:.4f}",
'LR': f"{self.optimizer.param_groups[0]['lr']:.2e}",
'RW': f"{loss_dict['loss/region_weight']:.3f}" if self.gradnorm_enabled else "",
'NW': f"{loss_dict['loss/node_weight']:.3f}" if self.gradnorm_enabled else ""
})
# Calculate average loss components
num_batches = len(self.train_loader)
avg_losses = {key: value / num_batches for key, value in loss_accumulators.items()}
# Convert to numpy arrays
all_global_preds = np.array(all_global_preds)
all_global_labels = np.array(all_global_labels)
all_node_preds = np.array(all_node_preds)
all_node_labels = np.array(all_node_labels)
# Calculate metrics
graph_metrics = calculate_graph_metrics(all_global_preds, all_global_labels, self.threshold)
node_metrics = calculate_node_metrics(all_node_preds, all_node_labels, find_threshold=False)
# Combine all metrics
metrics = {
'lr': self.optimizer.param_groups[0]['lr'],
**avg_losses, # Include all loss components
}
# Add prefixed metrics
for k, v in graph_metrics.items():
metrics[f'graph_{k}'] = v
for k, v in node_metrics.items():
metrics[f'node_{k}'] = v
# Print epoch summary with detailed loss breakdown
print(format_loss_info(avg_losses, "Train Loss - "))
print(f"Train Graph MCC: {graph_metrics['mcc']:.4f}, F1: {graph_metrics['f1']:.4f}, Recall: {graph_metrics['recall']:.4f}, Precision: {graph_metrics['precision']:.4f}")
print(f"Train Node AUPRC: {node_metrics['auprc']:.4f}, AUROC: {node_metrics['auroc']:.4f}, F1: {node_metrics['f1']:.4f} (threshold: {node_metrics['threshold_used']:.3f})")
# Print GradNorm info if enabled
if self.gradnorm_enabled:
print(f"GradNorm Weights - Region: {avg_losses.get('loss/region_weight', 1.0):.3f}, Node: {avg_losses.get('loss/node_weight', 1.0):.3f}")
if 'gradnorm/region_grad_norm' in avg_losses:
print(f"GradNorm Info - Region GradNorm: {avg_losses.get('gradnorm/region_grad_norm', 0):.4f}, Node GradNorm: {avg_losses.get('gradnorm/node_grad_norm', 0):.4f}")
return metrics
def _validate_epoch(self, epoch):
"""Validate for one epoch."""
self.model.eval()
# Initialize accumulators for loss components and predictions
loss_accumulators = defaultdict(float)
all_global_preds = []
all_global_labels = []
all_node_preds = []
all_node_labels = []
with torch.no_grad():
pbar = tqdm(self.val_loader, desc=f"Validation Epoch {epoch + 1}")
for batch in pbar:
batch = batch.to(self.device)
# Prepare targets
targets = {
'y': batch.y,
'y_node': batch.y_node
}
# Forward pass
if self.mixed_precision:
with autocast('cuda'):
outputs = self.model(batch)
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
else:
outputs = self.model(batch)
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
# Accumulate loss components from loss_dict
for key, value in loss_dict.items():
if key.startswith('loss/'):
loss_accumulators[key] += value
# Extract predictions from loss_dict (already detached)
global_preds = torch.sigmoid(loss_dict['logits/global']).cpu().numpy()
global_labels = batch.y.detach().cpu().numpy()
node_preds = torch.sigmoid(loss_dict['logits/node']).cpu().numpy()
node_labels = batch.y_node.detach().cpu().numpy()
all_global_preds.extend(global_preds)
all_global_labels.extend(global_labels)
all_node_preds.extend(node_preds)
all_node_labels.extend(node_labels)
# Update progress bar with detailed loss info
pbar.set_postfix({
'Total': f"{loss_dict['loss/total']:.4f}",
'Region': f"{loss_dict['loss/region']:.4f}",
'Node': f"{loss_dict['loss/node']:.4f}"
})
# Calculate average loss components
num_batches = len(self.val_loader)
avg_losses = {key: value / num_batches for key, value in loss_accumulators.items()}
# Convert to numpy arrays
all_global_preds = np.array(all_global_preds)
all_global_labels = np.array(all_global_labels)
all_node_preds = np.array(all_node_preds)
all_node_labels = np.array(all_node_labels)
# Calculate metrics
graph_metrics = calculate_graph_metrics(all_global_preds, all_global_labels, self.threshold)
node_metrics = calculate_node_metrics(all_node_preds, all_node_labels, find_threshold=True)
# Combine all metrics
metrics = {
'epoch': epoch,
**avg_losses, # Include all loss components
}
# Add prefixed metrics
for k, v in graph_metrics.items():
metrics[f'graph_{k}'] = v
for k, v in node_metrics.items():
metrics[f'node_{k}'] = v
metrics['threshold'] = node_metrics['threshold_used']
# Print epoch summary with detailed loss breakdown
print(format_loss_info(avg_losses, "Val Loss - "))
print(f"Val Graph MCC: {graph_metrics['mcc']:.4f}, F1: {graph_metrics['f1']:.4f}, Recall: {graph_metrics['recall']:.4f}, Precision: {graph_metrics['precision']:.4f}")
print(f"Val Node AUPRC: {node_metrics['auprc']:.4f}, AUROC: {node_metrics['auroc']:.4f}, F1: {node_metrics['f1']:.4f}, MCC: {node_metrics['mcc']:.4f}, Precision: {node_metrics['precision']:.4f}, Recall: {node_metrics['recall']:.4f} (threshold: {node_metrics['threshold_used']:.3f})")
return metrics
def evaluate(self, model_path=None, split='test'):
"""
Evaluate the model on test set.
Args:
model_path: Path to model checkpoint (if None, uses current model)
split: Which split to evaluate ('test', 'val', 'train')
"""
print(f"[INFO] Evaluating on {split} set...")
# Load model if path provided
if model_path:
print(f"[INFO] Loading model from {model_path}")
self.model, _ = ReCEP.load(model_path, device=self.device)
# Select data loader
if split == 'test':
data_loader = self.test_loader
elif split == 'val' and self.val_loader:
data_loader = self.val_loader
elif split == 'train':
data_loader = self.train_loader
else:
raise ValueError(f"Invalid split: {split}")
self.model.eval()
# Initialize accumulators for loss components and predictions
loss_accumulators = defaultdict(float)
all_global_preds = []
all_global_labels = []
all_node_preds = []
all_node_labels = []
with torch.no_grad():
pbar = tqdm(data_loader, desc=f"Evaluating {split}")
for batch in pbar:
batch = batch.to(self.device)
# Prepare targets
targets = {
'y': batch.y,
'y_node': batch.y_node
}
# Forward pass
if self.mixed_precision:
with autocast('cuda'):
outputs = self.model(batch)
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
else:
outputs = self.model(batch)
loss, loss_dict = self.criterion(outputs, targets, batch.batch)
# Accumulate loss components from loss_dict
for key, value in loss_dict.items():
if key.startswith('loss/'):
loss_accumulators[key] += value
# Extract predictions from loss_dict (already detached)
global_preds = torch.sigmoid(loss_dict['logits/global']).cpu().numpy()
global_labels = batch.y.detach().cpu().numpy()
node_preds = torch.sigmoid(loss_dict['logits/node']).cpu().numpy()
node_labels = batch.y_node.detach().cpu().numpy()
all_global_preds.extend(global_preds)
all_global_labels.extend(global_labels)
all_node_preds.extend(node_preds)
all_node_labels.extend(node_labels)
# Calculate average loss components
num_batches = len(data_loader)
avg_losses = {key: value / num_batches for key, value in loss_accumulators.items()}
# Convert to numpy arrays
all_global_preds = np.array(all_global_preds)
all_global_labels = np.array(all_global_labels)
all_node_preds = np.array(all_node_preds)
all_node_labels = np.array(all_node_labels)
# Calculate metrics (use optimal threshold for evaluation)
graph_metrics = calculate_graph_metrics(all_global_preds, all_global_labels, self.threshold)
node_metrics = calculate_node_metrics(all_node_preds, all_node_labels, find_threshold=True)
# Print results with detailed loss breakdown
print(f"\n=== {split.upper()} RESULTS ===")
print(format_loss_info(avg_losses, "Loss - "))
print(f"\nGraph-level Metrics (Recall Prediction):")
print(f" MCC: {graph_metrics['mcc']:.4f}")
print(f" F1: {graph_metrics['f1']:.4f}")
print(f" Recall: {graph_metrics['recall']:.4f}")
print(f" Precision: {graph_metrics['precision']:.4f}")
print(f" MSE: {graph_metrics['mse']:.4f}")
print(f" MAE: {graph_metrics['mae']:.4f}")
print(f" R²: {graph_metrics['r2']:.4f}")
print(f" Pearson r: {graph_metrics['pearson_r']:.4f}")
print(f"\nNode-level Metrics (Epitope Prediction):")
print(f" AUPRC: {node_metrics['auprc']:.4f}")
print(f" AUROC: {node_metrics['auroc']:.4f}")
print(f" F1: {node_metrics['f1']:.4f} (optimal threshold: {node_metrics['best_threshold']:.3f})")
print(f" MCC: {node_metrics['mcc']:.4f}")
print(f" Precision: {node_metrics['precision']:.4f}")
print(f" Recall: {node_metrics['recall']:.4f}")
print("=" * 40)
# Save results with detailed loss components
results = {
'split': split,
'loss_components': self.convert_to_serializable(avg_losses),
'graph_metrics': self.convert_to_serializable(graph_metrics),
'node_metrics': self.convert_to_serializable(node_metrics),
'model_path': str(model_path) if model_path else None
}
results_file = self.results_dir / f"{split}_results.json"
with open(results_file, 'w') as f:
json.dump(results, f, indent=2)
return results
def _save_training_history(self):
"""Save training history to file and generate loss plots."""
history = {
'train_history': self.convert_to_serializable(self.train_history),
'val_history': self.convert_to_serializable(self.val_history)
}
history_file = self.results_dir / "training_history.json"
with open(history_file, 'w') as f:
json.dump(history, f, indent=2)
print(f"[INFO] Training history saved to {history_file}")
# Generate loss plots
self._plot_training_curves()
def _plot_training_curves(self):
"""Generate and save training curves for losses."""
if not self.train_history:
print("[WARNING] No training history to plot")
return
# Extract epochs and loss data
train_epochs = list(range(1, len(self.train_history) + 1))
train_total_loss = [h.get('loss/total', 0) for h in self.train_history]
train_region_loss = [h.get('loss/region', 0) for h in self.train_history]
train_node_loss = [h.get('loss/node', 0) for h in self.train_history]
# Create subplots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training and Validation Loss Curves', fontsize=16, fontweight='bold')
# Plot total loss
axes[0, 0].plot(train_epochs, train_total_loss, 'b-', label='Train', linewidth=2, marker='o', markersize=3)
if self.val_history:
val_epochs = list(range(1, len(self.val_history) + 1))
val_total_loss = [h.get('loss/total', 0) for h in self.val_history]
axes[0, 0].plot(val_epochs, val_total_loss, 'r-', label='Validation', linewidth=2, marker='s', markersize=3)
axes[0, 0].set_title('Total Loss', fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Plot region loss (global loss)
axes[0, 1].plot(train_epochs, train_region_loss, 'b-', label='Train', linewidth=2, marker='o', markersize=3)
if self.val_history:
val_region_loss = [h.get('loss/region', 0) for h in self.val_history]
axes[0, 1].plot(val_epochs, val_region_loss, 'r-', label='Validation', linewidth=2, marker='s', markersize=3)
axes[0, 1].set_title('Region Loss (Global)', fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# Plot node loss
axes[1, 0].plot(train_epochs, train_node_loss, 'b-', label='Train', linewidth=2, marker='o', markersize=3)
if self.val_history:
val_node_loss = [h.get('loss/node', 0) for h in self.val_history]
axes[1, 0].plot(val_epochs, val_node_loss, 'r-', label='Validation', linewidth=2, marker='s', markersize=3)
axes[1, 0].set_title('Node Loss', fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Plot combined comparison
axes[1, 1].plot(train_epochs, train_total_loss, 'b-', label='Train Total', linewidth=2, alpha=0.8)
axes[1, 1].plot(train_epochs, train_region_loss, 'g-', label='Train Region', linewidth=2, alpha=0.8)
axes[1, 1].plot(train_epochs, train_node_loss, 'orange', label='Train Node', linewidth=2, alpha=0.8)
if self.val_history:
axes[1, 1].plot(val_epochs, val_total_loss, 'r--', label='Val Total', linewidth=2, alpha=0.8)
axes[1, 1].plot(val_epochs, val_region_loss, 'cyan', linestyle='--', label='Val Region', linewidth=2, alpha=0.8)
axes[1, 1].plot(val_epochs, val_node_loss, 'magenta', linestyle='--', label='Val Node', linewidth=2, alpha=0.8)
axes[1, 1].set_title('All Losses Comparison', fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
# Adjust layout
plt.tight_layout()
# Save the plot
plot_file = self.results_dir / "training_loss_curves.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] Training loss curves saved to {plot_file}")
# Also create individual plots for better readability
self._plot_individual_loss_curves()
def _plot_individual_loss_curves(self):
"""Generate individual loss curve plots for each loss type."""
if not self.train_history:
return
# Extract epochs and loss data
train_epochs = list(range(1, len(self.train_history) + 1))
loss_types = [
('total', 'Total Loss'),
('region', 'Region Loss (Global)'),
('node', 'Node Loss')
]
for loss_key, loss_title in loss_types:
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
# Training loss
train_loss = [h.get(f'loss/{loss_key}', 0) for h in self.train_history]
ax.plot(train_epochs, train_loss, 'b-', label='Training', linewidth=2.5, marker='o', markersize=4)
# Validation loss
if self.val_history:
val_epochs = list(range(1, len(self.val_history) + 1))
val_loss = [h.get(f'loss/{loss_key}', 0) for h in self.val_history]
ax.plot(val_epochs, val_loss, 'r-', label='Validation', linewidth=2.5, marker='s', markersize=4)
# Formatting
ax.set_title(f'{loss_title} Over Training', fontsize=14, fontweight='bold')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Add some styling
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(0.5)
ax.spines['bottom'].set_linewidth(0.5)
# Save individual plot
plot_file = self.results_dir / f"{loss_key}_loss_curve.png"
plt.tight_layout()
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.close()
print(f"[INFO] {loss_title} curve saved to {plot_file}")
def save_config(self):
"""Save training configuration to JSON file."""
config = {
'model_config': self.model.get_config(), # Use model's get_config method
'training_config': {
'num_epoch': self.args.num_epoch,
'batch_size': self.args.batch_size,
'lr': self.args.lr,
'weight_decay': self.args.weight_decay,
'patience': self.args.patience,
'threshold': self.args.threshold,
'mixed_precision': self.args.mixed_precision,
'device_id': self.args.device_id
},
'data_config': {
'radii': self.args.radii,
'zero_ratio': self.args.zero_ratio,
'undersample': self.args.undersample,
'seed': self.args.seed
},
'loss_config': {
'region_loss_type': self.args.region_loss_type,
'reg_weight': self.args.reg_weight,
'cls_type': self.args.cls_type,
'gamma_high_cls': self.args.gamma_high_cls,
'regression_type': self.args.regression_type,
'node_loss_type': self.args.node_loss_type,
'alpha': self.args.alpha,
'gamma': self.args.gamma,
'pos_weight': self.args.pos_weight,
'node_loss_weight': self.args.node_loss_weight,
'region_weight': self.args.region_weight,
'consistency_weight': self.args.consistency_weight,
'consistency_type': self.args.consistency_type,
'label_smoothing': self.args.label_smoothing,
'gradnorm': getattr(self.args, 'gradnorm', False),
'gradnorm_alpha': getattr(self.args, 'gradnorm_alpha', 1.5),
'gradnorm_update_freq': getattr(self.args, 'gradnorm_update_freq', 10)
}
}
config_file = self.model_dir / "config.json"
with open(config_file, 'w') as f:
json.dump(config, f, indent=2)
print(f"[INFO] Configuration saved to {config_file}")