import torch
import logging
from torch import nn

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class GradientClippingHook:
    """Hook to clip gradients during training to prevent explosions."""

    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def __call__(self, module, grad_input, grad_output):
        # clip_grad_norm_ only touches parameters' .grad fields, so rescale the raw
        # gradient tensors directly and return the new tuple for PyTorch to use.
        if grad_input is None:
            return None
        clipped = []
        for g in grad_input:
            if g is not None:
                norm = g.norm()
                if norm > self.max_norm:
                    g = g * (self.max_norm / (norm + 1e-6))
                    logger.debug(f"Gradient clipped to max norm {self.max_norm}")
            clipped.append(g)
        return tuple(clipped)
class LossSpikeDetectionHook:
    """Hook to detect and log loss spikes during training."""

    def __init__(self, threshold=10.0, window=5):
        self.threshold = threshold
        self.window = window
        self.losses = []

    def __call__(self, module, input, output):
        # The mean of the module's output is used as a proxy signal; skip tuple
        # outputs (e.g. from RNNs) and anything not produced during training.
        if not isinstance(output, torch.Tensor) or not output.requires_grad:
            return
        self.losses.append(output.detach().mean().item())
        if len(self.losses) > self.window:
            self.losses.pop(0)
        if len(self.losses) > 1:
            avg_loss = sum(self.losses[:-1]) / (len(self.losses) - 1)
            current_loss = self.losses[-1]
            if current_loss > avg_loss * self.threshold:
                logger.warning(
                    f"Loss spike detected: {current_loss:.4f} > {avg_loss:.4f} * {self.threshold}"
                )
# Utility function to register hooks on a model
def register_azure_hooks(model):
    grad_clip_hook = GradientClippingHook(max_norm=1.0)
    loss_spike_hook = LossSpikeDetectionHook(threshold=10.0, window=5)
    handles = []
    # The same hook instances are shared across all submodules.
    for module in model.modules():
        handles.append(module.register_full_backward_hook(grad_clip_hook))
        handles.append(module.register_forward_hook(loss_spike_hook))
    logger.info("Azure hooks registered: GradientClippingHook and LossSpikeDetectionHook")
    return handles
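

# Minimal usage sketch: registers the hooks on a small model and runs a few
# training steps. The toy model, optimizer settings, and random data below are
# illustrative assumptions, not part of the original code.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    handles = register_azure_hooks(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(32, 8), torch.randn(32, 1)

    for step in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()  # backward hooks clip per-layer gradients here
        optimizer.step()
        logger.info(f"step {step}: loss={loss.item():.4f}")

    # Hooks can be detached via the returned handles once training is done.
    for h in handles:
        h.remove()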