Update utils/gradient_clipping.py
Browse files- utils/gradient_clipping.py +13 -0
utils/gradient_clipping.py
CHANGED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def clip_gradients(model, max_norm=1.0, norm_type=2):
|
5 |
+
"""
|
6 |
+
Clip the gradients of the model parameters in place.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
model (torch.nn.Module): The model whose gradients will be clipped.
|
10 |
+
max_norm (float): Maximum allowed norm of the gradients.
|
11 |
+
norm_type (int or float): Type of the used p-norm (e.g., 2 for L2 norm).
|
12 |
+
"""
|
13 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)
|