|
import inspect |
|
import torch.nn as nn |
|
from typing import List, Optional, Union |
|
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper |
|
from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, |
|
OPTIMIZERS) |
|
|
|
|
|
def add_weight_decay(model, weight_decay=1e-5, skip_list=(),
                     no_decay_keywords=('diffloss',)):
    """Split a model's trainable parameters into decay / no-decay groups.

    Parameters are excluded from weight decay when they are 1-D (biases,
    norm scales), end with ``.bias``, appear in ``skip_list``, or contain
    any of ``no_decay_keywords`` as a substring of their name.

    Args:
        model (nn.Module): Model whose ``named_parameters()`` are grouped.
        weight_decay (float): Decay applied to the decayed group.
        skip_list (tuple): Exact parameter names to exempt from decay.
        no_decay_keywords (tuple): Name substrings to exempt from decay.
            Defaults to ``('diffloss',)`` to preserve prior behavior.

    Returns:
        list[dict]: Two param-group dicts — ``[no_decay, decay]`` — suitable
        for passing to a torch optimizer.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        # Frozen parameters are not handed to the optimizer at all.
        if not param.requires_grad:
            continue
        if (len(param.shape) == 1 or name.endswith(".bias")
                or name in skip_list
                or any(kw in name for kw in no_decay_keywords)):
            no_decay.append(param)
        else:
            decay.append(param)

    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, with {num_nodecay_params:,} parameters")

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
|
|
|
|
|
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optim wrapper constructor that groups params via ``add_weight_decay``.

    Instead of the default paramwise handling, parameters are split into a
    zero-weight-decay group (1-D params, ``.bias`` params, 'diffloss' params)
    and a decayed group before the optimizer is built.

    NOTE(review): ``OPTIM_WRAPPER_CONSTRUCTORS`` was imported but unused, so
    the registration decorator above was added — confirm the class is not
    registered anywhere else to avoid a duplicate-key error.
    """

    def __call__(self, model: nn.Module) -> OptimWrapper:
        """Build an :class:`OptimWrapper` for ``model``.

        Args:
            model (nn.Module): Model to optimize. If it is a wrapped model
                (e.g. DDP) exposing ``.module``, the inner module is used.

        Returns:
            OptimWrapper: The built optimizer wrapper.
        """
        # Unwrap DDP/DataParallel-style containers.
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = self.optimizer_cfg['type']

        # Resolve a string type (e.g. 'AdamW') to the optimizer class so we
        # can inspect its constructor signature below.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])

        # First positional argument of the optimizer ctor (usually 'params').
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))

        # Replace the flat parameter list with decay / no-decay groups;
        # 'weight_decay' is popped so it is applied per-group, not globally.
        param_groups = add_weight_decay(model,
                                        optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
|
|