a-ragab-h-m committed on
Commit
14c20ce
·
verified ·
1 Parent(s): fcf5dbb

Update utils/gradient_clipping.py

Files changed (1)
  1. utils/gradient_clipping.py +13 -0
utils/gradient_clipping.py CHANGED
@@ -0,0 +1,13 @@
+ import torch
+
+
+ def clip_gradients(model, max_norm=1.0, norm_type=2):
+     """
+     Clip the gradients of the model parameters in place.
+
+     Args:
+         model (torch.nn.Module): The model whose gradients will be clipped.
+         max_norm (float): Maximum allowed norm of the gradients.
+         norm_type (int or float): Type of the used p-norm (e.g., 2 for L2 norm).
+     """
+     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)
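
A minimal sketch of how this utility would typically be used in a training step: gradients are clipped after loss.backward() populates them and before optimizer.step() applies them. The model, optimizer, loss function, and random tensors below are placeholder assumptions for illustration, as is the import path, which assumes the repository root is on sys.path; only the call order reflects the function added in this commit.

import torch

from utils.gradient_clipping import clip_gradients  # path assumed from this commit

# Placeholder model and optimizer for the sketch.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

# Dummy batch.
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()                       # gradients are populated here
clip_gradients(model, max_norm=1.0)   # rescales gradients in place if their L2 norm exceeds 1.0
optimizer.step()                      # update uses the clipped gradients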