# UniPic/src/optimisers/constructor.py
import inspect

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                               OPTIMIZERS)

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split parameters into decay / no-decay groups.

    Biases, 1-D tensors (e.g. norm scales) and any parameter whose name
    contains 'diffloss' are excluded from weight decay.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (len(param.shape) == 1 or name.endswith(".bias")
                or name in skip_list or 'diffloss' in name):
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)
    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, "
          f"with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, "
          f"with {num_nodecay_params:,} parameters")
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
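
# A minimal usage sketch for add_weight_decay (the Linear model below is a
# hypothetical stand-in, not part of this module):
#
#   import torch
#   model = nn.Linear(16, 16)
#   groups = add_weight_decay(model, weight_decay=0.05)
#   optimizer = torch.optim.AdamW(groups, lr=1e-4)
#
# The 1-D bias lands in the zero-decay group, while the 2-D weight matrix
# receives the configured decay.
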
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Constructor that builds parameter groups via ``add_weight_decay``
    instead of mmengine's ``paramwise_cfg`` mechanism."""

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module
        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']
        # Optimizers like HybridAdam in ColossalAI require the argument
        # name `model_params` rather than `params`, so look up the first
        # argument name of the optimizer class and fill it with the
        # parameter groups.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))
        # `weight_decay` is popped from the cfg and re-applied per group.
        param_groups = add_weight_decay(model,
                                        optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)
        # Unlike DefaultOptimWrapperConstructor, `paramwise_cfg` is not
        # consulted here: parameter groups always come from add_weight_decay.
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
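
# A minimal mmengine config sketch showing how this constructor might be
# selected (the optimizer type and hyperparameters below are illustrative
# assumptions, not values taken from this repository's configs):
#
#   optim_wrapper = dict(
#       type='OptimWrapper',
#       optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
#       constructor='MAROptimWrapperConstructor')
#
# With `constructor` set, mmengine routes optimizer construction through
# MAROptimWrapperConstructor, so `weight_decay` is popped from the optimizer
# cfg and applied per-group by add_weight_decay.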