# UniPic/src/optimisers/constructor.py
import inspect

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                               OPTIMIZERS)

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split ``model``'s trainable parameters into decay / no-decay groups.

    Biases, 1-D parameters (e.g. norm weights), names listed in ``skip_list``
    and any parameter whose name contains ``'diffloss'`` get zero weight
    decay; all remaining parameters use ``weight_decay``.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (len(param.shape) == 1 or name.endswith(".bias")
                or name in skip_list or 'diffloss' in name):
            no_decay.append(param)  # no weight decay on bias, norm and diffloss
        else:
            decay.append(param)
    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, "
          f"with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, "
          f"with {num_nodecay_params:,} parameters")
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]
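

# Example usage of ``add_weight_decay`` (a hedged sketch with a toy model,
# not taken from UniPic itself):
#
#   import torch
#   model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
#   groups = add_weight_decay(model, weight_decay=0.05)
#   optimizer = torch.optim.AdamW(groups, lr=1e-4)
#
# The LayerNorm weight/bias and the Linear bias fall into the zero-decay
# group; only the 2-D Linear weight matrix receives weight_decay=0.05.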


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Constructor that builds optimizer param groups via ``add_weight_decay``."""

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']
        # Optimizers like HybridAdam in ColossalAI require the argument name
        # `model_params` rather than `params`. Here we get the first argument
        # name and fill it with the model parameter groups.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))
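        # For torch.optim optimizers such as AdamW this first argument is
        # `params`; for ColossalAI's HybridAdam it is `model_params` (exact
        # signatures depend on the installed packages).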
        # Build the decay / no-decay groups instead of the param-wise options
        # handled by DefaultOptimWrapperConstructor.
        param_groups = add_weight_decay(model,
                                        optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)
        # # if no paramwise option is specified, just use the global setting
        # if not self.paramwise_cfg:
        #     optimizer_cfg[first_arg_name] = model.parameters()
        #     optimizer = OPTIMIZERS.build(optimizer_cfg)
        # else:
        #     # set param-wise lr and weight decay recursively
        #     params: List = []
        #     self.add_params(params, model)
        #     optimizer_cfg[first_arg_name] = params
        #     optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
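

# A minimal usage sketch (kept as comments, since this module only defines the
# constructor). The config below is an assumption about how a typical mmengine
# training config would reference it, not a snippet taken from UniPic; the
# AdamW hyperparameters are illustrative.
#
#   optim_wrapper = dict(
#       type='OptimWrapper',
#       constructor='MAROptimWrapperConstructor',
#       optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05,
#                      betas=(0.9, 0.95)),
#   )
#
# mmengine builds the constructor from this config and calls it on the model;
# the resulting OptimWrapper wraps an AdamW optimizer with the two parameter
# groups produced by add_weight_decay (weight decay applied only to the
# "decay" group).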