# Copyright (c) OpenMMLab. All rights reserved.
# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501
from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone


class MobileOneBlock(BaseModule):
    """MobileOne block for MobileOne backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        kernel_size (int): The kernel size of the convs in the block. If the
            kernel size is larger than 1, there will be a ``branch_scale`` in
            the block.
        num_convs (int): Number of the convolution branches in the block.
        stride (int): Stride of convolution layers. Defaults to 1.
        padding (int): Padding of the convolution layers. Defaults to 1.
        dilation (int): Dilation of the convolution layers. Defaults to 1.
        groups (int): Groups of the convolution layers. Defaults to 1.
        se_cfg (None or dict): The configuration of the se module.
            Defaults to None.
        conv_cfg (dict, optional): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): Configuration to construct and config norm layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        deploy (bool): Whether the model structure is in the deployment mode.
            Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
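
    Example:
        >>> # Illustrative usage only; the channel and shape values below are
        >>> # arbitrary, not taken from a MobileOne preset.
        >>> import torch
        >>> block = MobileOneBlock(32, 32, kernel_size=3, num_convs=4,
        ...                        groups=32)
        >>> x = torch.rand(1, 32, 56, 56)
        >>> block(x).shape
        torch.Size([1, 32, 56, 56])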
| """ | |

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 num_convs: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 se_cfg: Optional[dict] = None,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = dict(type='BN'),
                 act_cfg: Optional[dict] = dict(type='ReLU'),
                 deploy: bool = False,
                 init_cfg: Optional[dict] = None):
        super(MobileOneBlock, self).__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)
        if se_cfg is not None:
            self.se = SELayer(channels=out_channels, **se_cfg)
        else:
            self.se = nn.Identity()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.num_conv_branches = num_convs
        self.stride = stride
        self.padding = padding
        self.se_cfg = se_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy
        self.groups = groups
        self.dilation = dilation
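
        # In deployment mode the multi-branch structure is replaced by a
        # single reparameterized convolution with bias.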
        if deploy:
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                groups=self.groups,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=True)
        else:
            # Check whether input and output shapes are the same.
            # If so, add a BatchNorm-only identity shortcut.
            if out_channels == in_channels and stride == 1:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None

            self.branch_scale = None
            if kernel_size > 1:
                self.branch_scale = self.create_conv_bn(kernel_size=1)

            self.branch_conv_list = ModuleList()
            for _ in range(num_convs):
                self.branch_conv_list.append(
                    self.create_conv_bn(
                        kernel_size=kernel_size,
                        padding=padding,
                        dilation=dilation))

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
        """Create a (conv + bn) Sequential layer."""
        conv_bn = Sequential()
        conv_bn.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                groups=self.groups,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                bias=False))
        conv_bn.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
        return conv_bn

    def forward(self, x):

        def _inner_forward(inputs):
            if self.deploy:
                return self.branch_reparam(inputs)
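
            # Training-time forward: sum the BN identity branch, the 1x1
            # scale branch (if any) and all kxk conv branches.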
            inner_out = 0
            if self.branch_norm is not None:
                inner_out = self.branch_norm(inputs)

            if self.branch_scale is not None:
                inner_out += self.branch_scale(inputs)

            for branch_conv in self.branch_conv_list:
                inner_out += branch_conv(inputs)

            return inner_out

        return self.act(self.se(_inner_forward(x)))

    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment
        mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."
        reparam_weight, reparam_bias = self.reparameterize()

        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias

        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_conv_list')
        if hasattr(self, 'branch_scale'):
            delattr(self, 'branch_scale')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse all the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. The first element is the weights and the second is
                the bias.
        """
        weight_conv, bias_conv = 0, 0
        for branch_conv in self.branch_conv_list:
            weight, bias = self._fuse_conv_bn(branch_conv)
            weight_conv += weight
            bias_conv += bias

        weight_scale, bias_scale = 0, 0
        if self.branch_scale is not None:
            weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            weight_scale = F.pad(weight_scale, [pad, pad, pad, pad])

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_conv + weight_scale + weight_norm,
                bias_conv + bias_scale + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmengine.model.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
        """
        if branch is None:
            return 0, 0
        kernel = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps
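        # Fold BN into the conv:
        #   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
        # => fused_weight = W * gamma / std (scaled per output channel)
        #    fused_bias   = beta - mean * gamma / std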
        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel
        fused_bias = beta - running_mean * gamma / std

        return fused_weight, fused_bias

    def _norm_to_conv(self, branch_norm):
        """Convert a norm layer to a conv-bn sequence towards
        ``self.kernel_size``.

        Args:
            branch_norm (nn.BatchNorm2d): A branch with only a bn layer.

        Returns:
            (mmengine.model.Sequential): A Sequential with conv and bn.
        """
        input_dim = self.in_channels // self.groups
        conv_weight = torch.zeros(
            (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
            dtype=branch_norm.weight.dtype)

        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, self.kernel_size // 2,
                        self.kernel_size // 2] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size)
        tmp_conv.conv.weight.data = conv_weight
        tmp_conv.norm = branch_norm
        return tmp_conv


@MODELS.register_module()
class MobileOne(BaseBackbone):
    """MobileOne backbone.

    A PyTorch impl of: `An Improved One millisecond Mobile Backbone
    <https://arxiv.org/pdf/2206.04040.pdf>`_

    Args:
        arch (str | dict): MobileOne architecture. If a string, choose
            from 's0', 's1', 's2', 's3' and 's4'. If a dict, it should
            have the keys below:

            - num_blocks (Sequence[int]): Number of blocks in each stage.
            - width_factor (Sequence[float]): Width factor in each stage.
            - num_conv_branches (Sequence[int]): Number of conv branches
              in each stage.
            - num_se_blocks (Sequence[int]): Number of SE layers in each
              stage, all the SE layers are placed in the last blocks
              of each stage.

            Defaults to 's0'.
        in_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Sequence[int] | int): Output from which stages.
            Defaults to ``(3, )``.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        se_cfg (dict): Config dict for the SE layers.
            Defaults to ``dict(ratio=16)``.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Defaults to False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> from mmpretrain.models import MobileOne
        >>> import torch
        >>> x = torch.rand(1, 3, 224, 224)
        >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))
        >>> model.eval()
        >>> outputs = model(x)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 48, 56, 56)
        (1, 128, 28, 28)
        (1, 256, 14, 14)
        (1, 1024, 7, 7)
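
        A custom architecture can also be passed as a dict (the values
        below are purely illustrative, not a released preset):

        >>> arch = dict(num_blocks=[2, 8, 10, 1],
        ...             width_factor=[1.0, 1.0, 1.0, 1.0],
        ...             num_conv_branches=[1, 1, 1, 1],
        ...             num_se_blocks=[0, 0, 0, 0])
        >>> model = MobileOne(arch, out_indices=(3, ))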
| """ | |

    arch_zoo = {
        's0':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[0.75, 1.0, 1.0, 2.0],
            num_conv_branches=[4, 4, 4, 4],
            num_se_blocks=[0, 0, 0, 0]),
        's1':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[1.5, 1.5, 2.0, 2.5],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's2':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[1.5, 2.0, 2.5, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's3':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[2.0, 2.5, 3.0, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's4':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[3.0, 3.5, 3.5, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 5, 1])
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 se_cfg=dict(ratio=16),
                 deploy=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(type='Constant', val=1, layer=['_BatchNorm'])
                 ]):
        super(MobileOne, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_zoo, f'"arch": "{arch}"' \
                f' is not one of the {list(self.arch_zoo.keys())}'
            arch = self.arch_zoo[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        self.arch = arch
        for k, value in self.arch.items():
            assert isinstance(value, list) and len(value) == 4, \
                f'the value of {k} in arch must be a list with 4 items.'

        self.in_channels = in_channels
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.se_cfg = se_cfg
        self.act_cfg = act_cfg

        base_channels = [64, 128, 256, 512]
        channels = min(64,
                       int(base_channels[0] * self.arch['width_factor'][0]))
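        # The stem width is capped at 64 channels even for wide variants.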
        self.stage0 = MobileOneBlock(
            self.in_channels,
            channels,
            stride=2,
            kernel_size=3,
            num_convs=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)

        self.in_planes = channels
        self.stages = []
        for i, num_blocks in enumerate(self.arch['num_blocks']):
            planes = int(base_channels[i] * self.arch['width_factor'][i])

            stage = self._make_stage(planes, num_blocks,
                                     arch['num_se_blocks'][i],
                                     arch['num_conv_branches'][i])

            stage_name = f'stage{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = len(self.stages) + index
            assert 0 <= out_indices[i] < len(self.stages), \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

    def _make_stage(self, planes, num_blocks, num_se, num_conv_branches):
        strides = [2] + [1] * (num_blocks - 1)
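        # The first block of each stage downsamples with stride 2.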
        if num_se > num_blocks:
            raise ValueError('Number of SE blocks cannot '
                             'exceed number of layers.')
        blocks = []
        for i in range(num_blocks):
            use_se = False
            if i >= (num_blocks - num_se):
                use_se = True

            blocks.append(
                # Depthwise conv
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=self.in_planes,
                    kernel_size=3,
                    num_convs=num_conv_branches,
                    stride=strides[i],
                    padding=1,
                    groups=self.in_planes,
                    se_cfg=self.se_cfg if use_se else None,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy))

            blocks.append(
                # Pointwise conv
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=1,
                    num_convs=num_conv_branches,
                    stride=1,
                    padding=0,
                    se_cfg=self.se_cfg if use_se else None,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy))

            self.in_planes = planes

        return Sequential(*blocks)

    def forward(self, x):
        x = self.stage0(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stage0.eval()
            for param in self.stage0.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage{i+1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch the model into training or evaluation mode."""
        super(MobileOne, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        """Switch the model to deployment mode, which has a smaller number
        of parameters and less computation."""
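        # Typical usage (illustrative): build and load weights in training
        # mode, then call `model.switch_to_deploy()` once before inference
        # or export; every MobileOneBlock fuses its branches in place.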
        for m in self.modules():
            if isinstance(m, MobileOneBlock):
                m.switch_to_deploy()
        self.deploy = True