from torch.optim import AdamW


class CustomAdamW(AdamW):
    """AdamW that splits parameters into weight-decay and no-weight-decay groups."""

    def __init__(self, params, weight_decay, *args, **kwargs):
        # Accept either a name -> parameter mapping (e.g. dict(model.named_parameters()))
        # or a plain iterable of parameters; keep only those that require gradients.
        if isinstance(params, dict):
            params = [p for p in params.values() if p.requires_grad]
        else:
            params = [p for p in params if p.requires_grad]

        # Apply weight decay only to parameters with 2 or more dimensions
        # (weight matrices, embeddings); biases and norm scales are not decayed.
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Forward the remaining optimizer arguments (lr, betas, eps, ...) to AdamW;
        # pass the groups positionally so extra positional args cannot collide with `params`.
        super().__init__(optim_groups, *args, **kwargs)
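

# A minimal usage sketch (an assumption, not part of the original module): the tiny
# nn.Sequential model and the hyperparameters below are placeholders, purely to show
# how CustomAdamW consumes either a name -> parameter dict or an iterable.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Dict form: 2-D weights land in the decayed group, 1-D biases do not.
    optimizer = CustomAdamW(dict(model.named_parameters()), weight_decay=0.1, lr=3e-4)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()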