Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import build_norm_layer | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
class SwAVNeck(BaseModule):
    """The non-linear neck of SwAV: fc-bn-relu-fc-normalization.

    Depending on the channel arguments the projection is one of:

    - ``out_channels == 0``: identity (no projection),
    - ``hid_channels == 0``: a single linear layer,
    - otherwise: ``Linear -> norm -> ReLU -> Linear``.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. When False, the backbone features are fed to
            the projection as they are (only flattened to 2-D).
            Defaults to True.
        with_l2norm (bool): whether to normalize the output after
            projection. Defaults to True.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    # NOTE(review): the mutable defaults below are kept unchanged for
    # interface compatibility; this class never mutates them itself.
    # Presumably ``build_norm_layer`` / ``BaseModule`` copy them - confirm.
    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        with_avg_pool: bool = True,
        with_l2norm: bool = True,
        norm_cfg: dict = dict(type='SyncBN'),
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        self.with_l2norm = with_l2norm

        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if out_channels == 0:
            self.projection_neck = nn.Identity()
        elif hid_channels == 0:
            self.projection_neck = nn.Linear(in_channels, out_channels)
        else:
            # The norm layer is kept as its own attribute in addition to
            # being part of the Sequential, so existing attribute access
            # and state-dict keys stay the same.
            self.norm = build_norm_layer(norm_cfg, hid_channels)[1]
            self.projection_neck = nn.Sequential(
                nn.Linear(in_channels, hid_channels),
                self.norm,
                nn.ReLU(inplace=True),
                nn.Linear(hid_channels, out_channels),
            )

    def forward_projection(self, x: torch.Tensor) -> torch.Tensor:
        """Compute projection.

        Args:
            x (torch.Tensor): The feature vectors after pooling.

        Returns:
            torch.Tensor: The output features with projection or L2-norm.
        """
        x = self.projection_neck(x)
        if self.with_l2norm:
            # L2-normalize every feature vector along the channel dim.
            x = nn.functional.normalize(x, dim=1, p=2)
        return x

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Forward function.

        Args:
            x (List[torch.Tensor]): list of feature maps, len(x) according to
                len(num_crops). Only element ``[0]`` of every entry is used
                (presumably each entry is the backbone's output tuple for
                one crop resolution - verify against the caller).

        Returns:
            torch.Tensor: The projection vectors.
        """
        feats = []
        for crop_out in x:
            feat = crop_out[0]
            if self.with_avg_pool:
                feat = self.avgpool(feat)
            # Collect the feature whether or not it was pooled. Previously
            # nothing was collected when ``with_avg_pool`` was False, so the
            # concatenation below always failed in that mode.
            feats.append(feat)

        # Stack all crops along the batch dimension, then flatten every
        # sample to a vector: [sum(num_crops) * N, C] when pooled.
        feat_vec = torch.cat(feats)
        feat_vec = feat_vec.view(feat_vec.size(0), -1)
        return self.forward_projection(feat_vec)