| from networks.encoders.mobilenetv2 import MobileNetV2 | |
| from networks.encoders.mobilenetv3 import MobileNetV3Large | |
| from networks.encoders.resnet import ResNet101, ResNet50 | |
| from networks.encoders.resnest import resnest | |
| from networks.encoders.swin import build_swin_model | |
| from networks.layers.normalization import FrozenBatchNorm2d | |
| from torch import nn | |
def build_encoder(name, frozen_bn=True, freeze_at=-1):
    """Construct the backbone encoder selected by ``name``.

    Args:
        name: Encoder identifier. Exact matches are ``'mobilenetv2'``,
            ``'mobilenetv3'``, ``'resnet50'``, ``'resnet101'``,
            ``'resnest50'`` and ``'resnest101'``. Any other name that
            contains the substring ``'swin'`` is forwarded verbatim to
            ``build_swin_model``.
        frozen_bn: If true, ``FrozenBatchNorm2d`` is handed to the encoder
            as its normalization layer; otherwise ``nn.BatchNorm2d`` is
            used. The Swin branch does not receive a normalization layer.
        freeze_at: Passed through unchanged to the selected encoder's
            constructor as ``freeze_at``.

    Returns:
        The encoder object built by the matching constructor.

    Raises:
        NotImplementedError: If ``name`` matches none of the supported
            encoders. The message includes the rejected name.
    """
    norm_layer = FrozenBatchNorm2d if frozen_bn else nn.BatchNorm2d

    # NOTE(review): the leading positional 16 is presumably the output
    # stride expected by these encoders -- confirm against their signatures.
    if name == 'mobilenetv2':
        return MobileNetV2(16, norm_layer, freeze_at=freeze_at)
    if name == 'mobilenetv3':
        return MobileNetV3Large(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnet50':
        return ResNet50(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnet101':
        return ResNet101(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnest50':
        return resnest.resnest50(norm_layer=norm_layer,
                                 dilation=2,
                                 freeze_at=freeze_at)
    if name == 'resnest101':
        return resnest.resnest101(norm_layer=norm_layer,
                                  dilation=2,
                                  freeze_at=freeze_at)
    # Substring match: every Swin variant name is resolved by the Swin
    # builder itself, so the full name is passed along untouched.
    if 'swin' in name:
        return build_swin_model(name, freeze_at=freeze_at)

    # Same exception type as before, but now it says which name was
    # rejected so configuration typos are easy to diagnose.
    raise NotImplementedError(
        "Unsupported encoder name: {!r}. Expected one of 'mobilenetv2', "
        "'mobilenetv3', 'resnet50', 'resnet101', 'resnest50', "
        "'resnest101', or a name containing 'swin'.".format(name))