import inspect
from typing import List, Optional, Union

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                               OPTIMIZERS)


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split parameters into a decayed group and a non-decayed group."""
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (len(param.shape) == 1 or name.endswith(".bias")
                or name in skip_list or 'diffloss' in name):
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)
    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, "
          f"with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, "
          f"with {num_nodecay_params:,} parameters")
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
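

# Illustrative aside (not from the original file): the groups returned above
# can be fed straight to a plain torch optimizer, e.g.
#     param_groups = add_weight_decay(model, weight_decay=0.02)
#     optimizer = torch.optim.AdamW(param_groups, lr=1e-4)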


# Register so configs can select this constructor by name. (The decorator is
# assumed: OPTIM_WRAPPER_CONSTRUCTORS is imported above but otherwise unused,
# so registration appears intended.)
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optim wrapper constructor that builds param groups via ``add_weight_decay``."""

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module
        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']
        # Optimizers like HybridAdam in colossalai require the argument name
        # `model_params` rather than `params`. Here we get the first argument
        # name and fill it with the model parameters.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))
        param_groups = add_weight_decay(model,
                                        optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)
        # Note: the default behaviour is bypassed here. The stock
        # ``DefaultOptimWrapperConstructor`` would either pass
        # ``model.parameters()`` directly (no ``paramwise_cfg``) or build the
        # groups recursively via ``self.add_params``; this class always uses
        # ``add_weight_decay`` instead.
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
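

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# how the constructor is typically driven; the toy model and hyper-parameters
# below are placeholders, and in a real mmengine config one would instead set
# constructor='MAROptimWrapperConstructor' on the optim_wrapper dict.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    toy_model = nn.Sequential(
        nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
    constructor = MAROptimWrapperConstructor(
        optim_wrapper_cfg=dict(
            type='OptimWrapper',
            optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.02)))
    wrapper = constructor(toy_model)
    # Expect two groups: biases/norm weights with weight_decay=0.0, the rest
    # with weight_decay=0.02.
    print([group['weight_decay'] for group in wrapper.optimizer.param_groups])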