Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import math | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
| from ..utils import build_norm_layer | |
| def is_pow2n(x): | |
| return x > 0 and (x & (x - 1) == 0) | |
| class ConvBlock2x(BaseModule): | |
| """The definition of convolution block.""" | |
| def __init__(self, | |
| in_channels: int, | |
| out_channels: int, | |
| mid_channels: int, | |
| norm_cfg: dict, | |
| act_cfg: dict, | |
| last_act: bool, | |
| init_cfg: Optional[dict] = None) -> None: | |
| super().__init__(init_cfg=init_cfg) | |
| self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False) | |
| self.norm1 = build_norm_layer(norm_cfg, mid_channels) | |
| self.activate1 = MODELS.build(act_cfg) | |
| self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False) | |
| self.norm2 = build_norm_layer(norm_cfg, out_channels) | |
| self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity() | |
| def forward(self, x: torch.Tensor): | |
| out = self.conv1(x) | |
| out = self.norm1(out) | |
| out = self.activate1(out) | |
| out = self.conv2(out) | |
| out = self.norm2(out) | |
| out = self.activate2(out) | |
| return out | |
| class DecoderConvModule(BaseModule): | |
| """The convolution module of decoder with upsampling.""" | |
| def __init__(self, | |
| in_channels: int, | |
| out_channels: int, | |
| mid_channels: int, | |
| kernel_size: int = 4, | |
| scale_factor: int = 2, | |
| num_conv_blocks: int = 1, | |
| norm_cfg: dict = dict(type='SyncBN'), | |
| act_cfg: dict = dict(type='ReLU6'), | |
| last_act: bool = True, | |
| init_cfg: Optional[dict] = None): | |
| super().__init__(init_cfg=init_cfg) | |
| assert (kernel_size - scale_factor >= 0) and\ | |
| (kernel_size - scale_factor) % 2 == 0,\ | |
| f'kernel_size should be greater than or equal to scale_factor '\ | |
| f'and (kernel_size - scale_factor) should be even numbers, '\ | |
| f'while the kernel size is {kernel_size} and scale_factor is '\ | |
| f'{scale_factor}.' | |
| padding = (kernel_size - scale_factor) // 2 | |
| self.upsample = nn.ConvTranspose2d( | |
| in_channels, | |
| in_channels, | |
| kernel_size=kernel_size, | |
| stride=scale_factor, | |
| padding=padding, | |
| bias=True) | |
| conv_blocks_list = [ | |
| ConvBlock2x( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| mid_channels=mid_channels, | |
| norm_cfg=norm_cfg, | |
| last_act=last_act, | |
| act_cfg=act_cfg) for _ in range(num_conv_blocks) | |
| ] | |
| self.conv_blocks = nn.Sequential(*conv_blocks_list) | |
| def forward(self, x): | |
| x = self.upsample(x) | |
| return self.conv_blocks(x) | |
| class SparKLightDecoder(BaseModule): | |
| """The decoder for SparK, which upsamples the feature maps. | |
| Args: | |
| feature_dim (int): The dimension of feature map. | |
| upsample_ratio (int): The ratio of upsample, equal to downsample_raito | |
| of the algorithm. | |
| mid_channels (int): The middle channel of `DecoderConvModule`. Defaults | |
| to 0. | |
| kernel_size (int): The kernel size of `ConvTranspose2d` in | |
| `DecoderConvModule`. Defaults to 4. | |
| scale_factor (int): The scale_factor of `ConvTranspose2d` in | |
| `DecoderConvModule`. Defaults to 2. | |
| num_conv_blocks (int): The number of convolution blocks in | |
| `DecoderConvModule`. Defaults to 1. | |
| norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN'). | |
| act_cfg (dict): Activation config. Defaults to dict(type='ReLU6'). | |
| last_act (bool): Whether apply the last activation in | |
| `DecoderConvModule`. Defaults to False. | |
| init_cfg (dict or list[dict], optional): Initialization config dict. | |
| """ | |
| def __init__( | |
| self, | |
| feature_dim: int, | |
| upsample_ratio: int, | |
| mid_channels: int = 0, | |
| kernel_size: int = 4, | |
| scale_factor: int = 2, | |
| num_conv_blocks: int = 1, | |
| norm_cfg: dict = dict(type='SyncBN'), | |
| act_cfg: dict = dict(type='ReLU6'), | |
| last_act: bool = False, | |
| init_cfg: Optional[dict] = [ | |
| dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']), | |
| dict(type='TruncNormal', std=0.02, layer=['Linear']), | |
| dict( | |
| type='Constant', | |
| val=1, | |
| layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm']) | |
| ], | |
| ): | |
| super().__init__(init_cfg=init_cfg) | |
| self.feature_dim = feature_dim | |
| assert is_pow2n(upsample_ratio) | |
| n = round(math.log2(upsample_ratio)) | |
| channels = [feature_dim // 2**i for i in range(n + 1)] | |
| self.decoder = nn.ModuleList([ | |
| DecoderConvModule( | |
| in_channels=c_in, | |
| out_channels=c_out, | |
| mid_channels=c_in if mid_channels == 0 else mid_channels, | |
| kernel_size=kernel_size, | |
| scale_factor=scale_factor, | |
| num_conv_blocks=num_conv_blocks, | |
| norm_cfg=norm_cfg, | |
| act_cfg=act_cfg, | |
| last_act=last_act) | |
| for (c_in, c_out) in zip(channels[:-1], channels[1:]) | |
| ]) | |
| self.proj = nn.Conv2d( | |
| channels[-1], 3, kernel_size=1, stride=1, bias=True) | |
| def forward(self, to_dec): | |
| x = 0 | |
| for i, d in enumerate(self.decoder): | |
| if i < len(to_dec) and to_dec[i] is not None: | |
| x = x + to_dec[i] | |
| x = self.decoder[i](x) | |
| return self.proj(x) | |