File size: 451 Bytes
14c20ce
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch


def clip_gradients(
    model: torch.nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2,
) -> torch.Tensor:
    """
    Clip the gradients of the model parameters in place.

    The gradients of every parameter yielded by ``model.parameters()`` are
    rescaled together by ``torch.nn.utils.clip_grad_norm_`` so that their
    combined norm does not exceed ``max_norm``.

    Args:
        model (torch.nn.Module): The model whose gradients will be clipped.
        max_norm (float): Maximum allowed norm of the gradients.
        norm_type (int or float): Type of the used p-norm (e.g., 2 for L2 norm).

    Returns:
        torch.Tensor: The value returned by ``clip_grad_norm_`` -- per the
        PyTorch documentation, the total norm of the gradients measured
        before clipping. Handy for logging or spotting exploding gradients;
        callers that do not need it can ignore it.
    """
    # Forward the norm instead of discarding it so callers can monitor it.
    return torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm,
        norm_type=norm_type,
    )