# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.models.heads import ClsHead
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer

class BatchNormLinear(BaseModule):
    """A BatchNorm1d followed by a Linear layer, fusable for deployment."""

    def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')):
        super(BatchNormLinear, self).__init__()
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        # Fold the BN affine transform into the Linear layer. In eval mode,
        # BN(x) = x * w + b with w = gamma / sqrt(var + eps) and
        # b = beta - mean * gamma / sqrt(var + eps), so
        # Linear(BN(x)) = x @ (W * w).T + (W @ b + bias).
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        b = self.bn.bias - self.bn.running_mean * \
            self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        w = self.linear.weight * w[None, :]
        b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias
        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear

    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x)
        return x
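
# A minimal sanity check (an illustrative addition, not part of the original
# module): after ``fuse()``, the returned plain ``nn.Linear`` should reproduce
# the BN -> Linear pipeline in eval mode. The helper name and the shapes
# (384 input features, batch of 2) are assumptions for demonstration only.
def _check_fuse_equivalence():
    layer = BatchNormLinear(384, 1000)
    layer.eval()  # use running statistics, as at deployment time
    x = torch.randn(2, 384)
    before = layer(x)          # BN(x) then Linear
    fused = layer.fuse()       # folds BN into the Linear weights
    after = fused(x)           # single fused Linear
    assert torch.allclose(before, after, atol=1e-5)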

def fuse_parameters(module):
    """Recursively replace every child exposing ``fuse()`` with its fused
    version."""
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)

@MODELS.register_module()
class LeViTClsHead(ClsHead):

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super(LeViTClsHead, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        self.deploy = deploy
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)

        if self.deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        # Fuse every BatchNormLinear in this head into a single nn.Linear.
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.pre_logits(x)
        if self.distillation:
            # Two parallel heads, e.g. (2, 16, 384) features -> (2, 1000) logits.
            x = self.head(x), self.head_dist(x)
            if not self.training:
                # At inference, average the classification and distillation logits.
                x = (x[0] + x[1]) / 2
            else:
                raise NotImplementedError("MMPretrain doesn't support "
                                          'training in distillation mode.')
        else:
            x = self.head(x)
        return x
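
# Usage sketch (an assumption, not part of the original module): fuse the BN
# layers of a head before export. ``in_channels=384`` is an illustrative
# value; the default ClsHead loss config is used.
if __name__ == '__main__':
    head = LeViTClsHead(num_classes=1000, in_channels=384)
    head.eval()
    head.switch_to_deploy()  # replaces each BatchNormLinear with nn.Linear
    print(head.head)         # now a plain nn.Linear(384, 1000)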