import torch
def clip_gradients(
    model: torch.nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2,
) -> torch.Tensor:
    """
    Clip the gradients of the model parameters in place.

    The gradients of all parameters are treated as one concatenated vector;
    if that vector's norm exceeds ``max_norm`` the gradients are rescaled
    in place so the norm is (approximately) ``max_norm``.

    Args:
        model (torch.nn.Module): The model whose gradients will be clipped.
        max_norm (float): Maximum allowed norm of the gradients.
        norm_type (int or float): Type of the used p-norm (e.g., 2 for L2
            norm, ``float("inf")`` for the infinity norm).

    Returns:
        torch.Tensor: Total norm of the gradients as measured *before*
        clipping. Useful for logging and for spotting exploding or
        non-finite gradients.
    """
    # clip_grad_norm_ already computes the total norm; hand it back to the
    # caller instead of discarding it so it does not need to be recomputed.
    return torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm, norm_type
    )