import torch


def clip_gradients(model, max_norm=1.0, norm_type=2):
    """
    Clip the gradients of the model parameters in place.

    Args:
        model (torch.nn.Module): The model whose gradients will be clipped.
        max_norm (float): Maximum allowed norm of the gradients. Must be
            non-negative.
        norm_type (int or float): Type of the used p-norm (e.g., 2 for L2
            norm, ``float('inf')`` for the infinity norm).

    Returns:
        torch.Tensor: The total norm of all parameter gradients (viewed as a
        single vector) measured *before* clipping, as returned by
        ``torch.nn.utils.clip_grad_norm_``. Useful for logging/monitoring.

    Raises:
        ValueError: If ``max_norm`` is negative.
    """
    # A negative max_norm would produce a negative clip coefficient and flip
    # the sign of every gradient instead of bounding its magnitude.
    if max_norm < 0:
        raise ValueError(f"max_norm must be non-negative, got {max_norm}")
    # Propagate the pre-clipping total norm instead of discarding it.
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)