Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import re | |
| from typing import Optional, Tuple | |
| import torch.nn as nn | |
| from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling, | |
| SparseBatchNorm2d, | |
| SparseConv2d, | |
| SparseMaxPooling, | |
| SparseSyncBatchNorm2d) | |
| from mmpretrain.registry import MODELS | |
| from .resnet import ResNet | |
class SparseResNet(ResNet):
    """ResNet with sparse module conversion function.

    Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Output channels of the stem layer. Defaults to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        expansion (int, optional): The expansion ratio of the residual blocks.
            If None, it is inferred from the block type. Defaults to None.
        num_stages (int): Stages of the network. Defaults to 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Defaults to False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Defaults to False.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
    """

    def __init__(self,
                 depth: int,
                 in_channels: int = 3,
                 stem_channels: int = 64,
                 base_channels: int = 64,
                 expansion: Optional[int] = None,
                 num_stages: int = 4,
                 strides: Tuple[int] = (1, 2, 2, 2),
                 dilations: Tuple[int] = (1, 1, 1, 1),
                 out_indices: Tuple[int] = (3, ),
                 style: str = 'pytorch',
                 deep_stem: bool = False,
                 avg_down: bool = False,
                 frozen_stages: int = -1,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = False,
                 init_cfg: Optional[dict] = [
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate: float = 0,
                 **kwargs):
        super().__init__(
            depth=depth,
            in_channels=in_channels,
            stem_channels=stem_channels,
            base_channels=base_channels,
            expansion=expansion,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            norm_eval=norm_eval,
            with_cp=with_cp,
            zero_init_residual=zero_init_residual,
            init_cfg=init_cfg,
            drop_path_rate=drop_path_rate,
            **kwargs)
        # Pick the synchronized sparse BN variant whenever the configured
        # norm type names a Sync* layer (plain substring test; the previous
        # ``re.search`` was equivalent but heavier).
        norm_type = norm_cfg['type']
        enable_sync_bn = 'Sync' in norm_type
        self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)

    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Convert regular dense modules to sparse modules.

        Recursively replaces ``nn.Conv2d``, pooling and 2d batch-norm
        layers with their sparse counterparts, copying learned parameters
        and running statistics where they exist.

        Args:
            m (nn.Module): The (sub)module to convert in place.
            enable_sync_bn (bool): Use ``SparseSyncBatchNorm2d`` instead of
                ``SparseBatchNorm2d`` for batch-norm layers.

        Returns:
            nn.Module: The converted module (a new object when ``m`` itself
            was replaced, otherwise ``m`` with converted children).
        """
        output = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            output.weight.data.copy_(m.weight.data)
            if bias:
                output.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            bn_cls = (SparseSyncBatchNorm2d
                      if enable_sync_bn else SparseBatchNorm2d)
            # ``num_features`` is always defined, unlike ``weight`` which is
            # None when ``affine=False`` (the old ``m.weight.shape[0]`` would
            # crash on such layers).
            output = bn_cls(
                m.num_features,
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats)
            # weight/bias exist only for affine BN; running stats only when
            # they are being tracked — guard both copies accordingly.
            if m.affine:
                output.weight.data.copy_(m.weight.data)
                output.bias.data.copy_(m.bias.data)
            if m.track_running_stats:
                output.running_mean.data.copy_(m.running_mean.data)
                output.running_var.data.copy_(m.running_var.data)
                output.num_batches_tracked.data.copy_(
                    m.num_batches_tracked.data)
        elif isinstance(m, (nn.Conv1d, )):
            raise NotImplementedError
        # Recurse into children; converted wrappers keep their child names.
        for name, child in m.named_children():
            output.add_module(
                name,
                self.dense_model_to_sparse(
                    child, enable_sync_bn=enable_sync_bn))
        del m
        return output