File size: 451 Bytes
14c20ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
def clip_gradients(model, max_norm=1.0, norm_type=2):
    """
    Clip the gradients of the model parameters in place.

    All gradients are rescaled jointly so that their combined
    ``norm_type``-norm does not exceed ``max_norm``. This is a thin wrapper
    around ``torch.nn.utils.clip_grad_norm_``.

    Args:
        model (torch.nn.Module or iterable of torch.Tensor): The model whose
            gradients will be clipped. Anything exposing ``.parameters()`` is
            treated as a model. Otherwise the argument is passed straight
            through, so an iterable of parameters (or a single tensor) also
            works.
        max_norm (float): Maximum allowed norm of the gradients. Must be
            non-negative.
        norm_type (int or float): Type of the used p-norm (e.g., 2 for L2 norm,
            ``float('inf')`` for the max norm).

    Returns:
        torch.Tensor: The total norm of the gradients measured before
        clipping, as returned by ``clip_grad_norm_``. This is useful for
        logging.

    Raises:
        ValueError: If ``max_norm`` is negative.
    """
    # A negative max_norm makes clip_grad_norm_'s scaling coefficient
    # negative, which would silently flip the direction of every gradient.
    if max_norm < 0:
        raise ValueError(f"max_norm must be non-negative, got {max_norm}")
    params = model.parameters() if hasattr(model, "parameters") else model
    return torch.nn.utils.clip_grad_norm_(params, max_norm, norm_type)
|