import inspect

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                               OPTIMIZERS)


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split parameters into two groups: with and without weight decay.

    Biases, 1-D parameters (e.g. norm layers) and any parameter whose name
    appears in ``skip_list`` or contains ``'diffloss'`` are excluded from
    weight decay.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen weights
        if (len(param.shape) == 1 or name.endswith('.bias')
                or name in skip_list or 'diffloss' in name):
            # no weight decay on biases, norm layers and diffloss parameters
            no_decay.append(param)
        else:
            decay.append(param)
    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f'num decayed parameter tensors: {len(decay)}, '
          f'with {num_decay_params:,} parameters')
    print(f'num non-decayed parameter tensors: {len(no_decay)}, '
          f'with {num_nodecay_params:,} parameters')
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optimizer wrapper constructor that builds parameter groups with
    :func:`add_weight_decay` instead of the default param-wise handling."""

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']
        # Optimizers like HybridAdam in ColossalAI require the argument name
        # `model_params` rather than `params`. Here we get the first argument
        # name and fill it with the parameter groups.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))

        # Unlike DefaultOptimWrapperConstructor, which applies paramwise_cfg
        # recursively, always build the two weight-decay groups via
        # `add_weight_decay`.
        param_groups = add_weight_decay(model,
                                        optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
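
# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the config excerpt
# and the toy model below are illustrative assumptions, showing how the
# constructor might be selected in an MMEngine config and what parameter
# groups `add_weight_decay` produces.
#
# Hypothetical config excerpt:
#
#   optim_wrapper = dict(
#       type='OptimWrapper',
#       optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.02),
#       constructor='MAROptimWrapperConstructor',
#   )
#
if __name__ == '__main__':
    # Toy stand-in model: the Linear weights get weight decay, while biases
    # and the 1-D LayerNorm parameters fall into the no-decay group.
    toy_model = nn.Sequential(
        nn.Linear(8, 16),
        nn.LayerNorm(16),
        nn.Linear(16, 4),
    )
    groups = add_weight_decay(toy_model, weight_decay=0.05)
    for group in groups:
        print(f"{len(group['params'])} tensors, "
              f"weight_decay={group['weight_decay']}")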