File size: 2,071 Bytes
1b579d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import logging
import math
from collections import deque

import torch
from torch import nn

# Configure logging: INFO level, timestamped "time - level - message" records.
# NOTE(review): basicConfig at import time configures the *root* logger for every
# program that imports this module; consider moving it to the application entry point.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)

class GradientClippingHook:
    """Full backward hook that bounds the norm of each gradient in ``grad_input``.

    Register it with ``nn.Module.register_full_backward_hook``. For every
    non-None gradient the hook computes its ``norm_type`` vector norm and scales
    it by ``min(1, max_norm / (norm + 1e-6))``. This is the same rule
    ``torch.nn.utils.clip_grad_norm_`` applies. The hook then returns the
    rescaled tuple, which autograd uses in place of ``grad_input``.

    Note: this bounds gradients with respect to the module's *inputs*. It does
    not modify parameter ``.grad`` tensors.

    Args:
        max_norm: Upper bound on each gradient's norm.
        norm_type: Order of the vector norm (default 2.0; ``float('inf')`` works).
    """

    def __init__(self, max_norm=1.0, norm_type=2.0):
        self.max_norm = max_norm
        self.norm_type = norm_type

    def __call__(self, module, grad_input, grad_output):
        """Return ``grad_input`` with every non-None gradient norm-clipped.

        Args:
            module: Module the hook is attached to (unused).
            grad_input: Tuple of gradients w.r.t. the module inputs; entries may be None.
            grad_output: Tuple of gradients w.r.t. the module outputs (unused).

        Returns:
            A new tuple aligned with ``grad_input`` (None entries preserved), or
            None when ``grad_input`` is None, which leaves gradients unchanged.
        """
        if grad_input is None:
            return None
        # Full backward hooks must not modify their arguments in place; returning
        # a new tuple is the supported way to replace grad_input.
        return tuple(self._clip(g) for g in grad_input)

    def _clip(self, grad):
        """Return ``grad`` scaled so its norm is at most ``max_norm``; None passes through."""
        if grad is None:
            return None
        # clip_grad_norm_ would act on ``grad.grad``, which is None for a gradient
        # tensor, so the scaling is computed explicitly here.
        with torch.no_grad():
            norm = torch.linalg.vector_norm(grad, ord=self.norm_type)
            # Clamp instead of branching on the norm so no device-to-host sync is needed.
            coef = torch.clamp(self.max_norm / (norm + 1e-6), max=1.0)
        logger.debug("Gradient clipped to max norm %s", self.max_norm)
        return grad * coef

class LossSpikeDetectionHook:
    """Forward hook that warns when a module's mean output spikes.

    On each call the module output is reduced to a scalar (its mean). That value
    is compared with the mean of the previous ``window`` finite values, and a
    warning is logged when it exceeds ``threshold`` times that baseline.

    Non-finite values (NaN/inf) are always reported. They are kept out of the
    history so they cannot poison later baselines.

    The history lives on the instance, so use one instance per module.

    Args:
        threshold: Multiplicative factor over the baseline that counts as a spike.
        window: Number of previous values averaged into the baseline (must be >= 1).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """

    def __init__(self, threshold=10.0, window=5):
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        self.threshold = threshold
        self.window = window
        # Most recent finite values, oldest first; maxlen evicts the oldest in O(1).
        self.losses = deque(maxlen=window)

    def __call__(self, module, input, output):
        """Record the mean of ``output`` and log a warning if it is a spike.

        Args:
            module: Module the hook is attached to (unused).
            input: Positional inputs of the forward call (unused).
            output: Module output. Only a non-empty floating-point tensor with
                ``requires_grad=True`` is tracked; anything else is ignored.
        """
        # Tuple/dict outputs (e.g. recurrent or attention layers) are skipped
        # instead of raising AttributeError inside the forward pass.
        if not isinstance(output, torch.Tensor) or not output.requires_grad:
            return
        if output.numel() == 0 or not output.is_floating_point():
            return

        # detach(): the statistic is for monitoring only and must not extend the graph.
        current_loss = output.detach().mean().item()
        if not math.isfinite(current_loss):
            # NaN compares False against everything, so a ratio test would never flag it.
            logger.warning("Non-finite loss detected: %s", current_loss)
            return

        if len(self.losses) == self.window:
            avg_loss = sum(self.losses) / self.window
            # The ratio test assumes a non-negative baseline. With a negative mean,
            # avg_loss * threshold lies below avg_loss and nearly every value would fire.
            if avg_loss >= 0 and current_loss > avg_loss * self.threshold:
                logger.warning(
                    "Loss spike detected: %.4f > %.4f * %s",
                    current_loss,
                    avg_loss,
                    self.threshold,
                )
        self.losses.append(current_loss)

# Utility function to register hooks on a model
def register_azure_hooks(model, *, max_norm=1.0, threshold=10.0, window=5):
    """Attach gradient-clipping and spike-detection hooks to every module in ``model``.

    Walks ``model.modules()``, which includes ``model`` itself and any container
    modules, and registers two hooks on each module:

    * a full backward hook that clips input gradients to ``max_norm``. A single
      ``GradientClippingHook`` is shared, since it keeps no per-call state.
    * a forward hook that warns on spikes in the module's mean output. A separate
      ``LossSpikeDetectionHook`` is created per module.

    Args:
        model: Root ``nn.Module`` to instrument.
        max_norm: Norm bound passed to ``GradientClippingHook``.
        threshold: Spike factor passed to each ``LossSpikeDetectionHook``.
        window: History length passed to each ``LossSpikeDetectionHook``.

    Returns:
        List of the handles returned by the ``register_*_hook`` calls. Call
        ``.remove()`` on each to detach the hooks.
    """
    grad_clip_hook = GradientClippingHook(max_norm=max_norm)
    handles = []
    for module in model.modules():
        # NOTE(review): containers are visited too, so the same gradient may be
        # clipped at several nesting levels. Full backward hooks are also known to
        # conflict with modules that modify inputs in place (e.g. ReLU(inplace=True));
        # confirm this is acceptable for the target models.
        handles.append(module.register_full_backward_hook(grad_clip_hook))
        # One detector per module: each keeps a rolling history, and a shared
        # instance would mix output statistics of unrelated layers into one baseline.
        spike_hook = LossSpikeDetectionHook(threshold=threshold, window=window)
        handles.append(module.register_forward_hook(spike_hook))
    logger.info("Azure hooks registered: GradientClippingHook and LossSpikeDetectionHook")
    return handles