# vrp-shanghai-transformer/utils/gradient_clipping.py
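import torch


def clip_gradients(model, max_norm=1.0, norm_type=2):
    """
    Clip the gradients of the model parameters in place.

    The norm is computed over all parameter gradients together, as if they
    were concatenated into a single vector. If that total norm exceeds
    max_norm, every gradient is scaled down by the same factor.

    Args:
        model (torch.nn.Module): The model whose gradients will be clipped.
        max_norm (float): Maximum allowed total norm of the gradients.
        norm_type (int or float): Type of the used p-norm (e.g., 2 for the
            L2 norm, float('inf') for the infinity norm).
    """
    # Call after loss.backward() and before optimizer.step().
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)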
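
To make the effect of norm-based clipping concrete, here is a minimal pure-Python sketch of the arithmetic that `torch.nn.utils.clip_grad_norm_` performs. It is an illustration under stated assumptions, not the library's implementation: `clip_by_total_norm` is a hypothetical helper, gradients are represented as plain lists of floats (one list per parameter), and the small `eps` term mirrors the stabilizing constant PyTorch adds to the denominator.

```python
import math


def clip_by_total_norm(grads, max_norm=1.0, norm_type=2.0, eps=1e-6):
    """Hypothetical helper: scale all gradients by one shared factor.

    grads is a list of per-parameter gradient lists (plain floats).
    Returns (clipped_grads, total_norm_before_clipping).
    """
    # Treat every gradient entry as part of one long vector.
    flat = [g for param in grads for g in param]
    if norm_type == math.inf:
        total_norm = max((abs(g) for g in flat), default=0.0)
    else:
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)

    # The factor is capped at 1.0, so small gradients are never scaled up.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in param] for param in grads]
    return clipped, total_norm


# Two "parameters" whose combined L2 norm is 5.0 (a 3-4-5 triangle).
large, large_norm = clip_by_total_norm([[3.0], [4.0]], max_norm=1.0)
large_norm_after = math.sqrt(sum(g * g for p in large for g in p))

# Gradients already inside the limit pass through unchanged.
small, small_norm = clip_by_total_norm([[0.3], [0.4]], max_norm=1.0)
```

Because a single factor is applied to every entry, the direction of the overall gradient vector is preserved and only its length shrinks. This is the difference from value clipping, which clamps each entry independently and can change the direction.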