| import torch | |
| from torch import nn | |
| import vietocr.model.backbone.vgg as vgg | |
| from vietocr.model.backbone.resnet import Resnet50 | |
class CNN(nn.Module):
    """Thin wrapper that builds and owns a CNN backbone selected by name.

    Args:
        backbone: one of ``'vgg11_bn'``, ``'vgg19_bn'`` or ``'resnet50'``.
        **kwargs: forwarded unchanged to the chosen backbone constructor.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, **kwargs):
        super().__init__()
        if backbone == 'vgg11_bn':
            self.model = vgg.vgg11_bn(**kwargs)
        elif backbone == 'vgg19_bn':
            self.model = vgg.vgg19_bn(**kwargs)
        elif backbone == 'resnet50':
            self.model = Resnet50(**kwargs)
        else:
            # Fail fast: an unknown name used to leave ``self.model`` unset,
            # which only surfaced later as an AttributeError in forward().
            raise ValueError(
                "Unsupported backbone {!r}; expected one of "
                "'vgg11_bn', 'vgg19_bn', 'resnet50'".format(backbone)
            )

    def forward(self, x):
        """Run the selected backbone on ``x`` and return its output unchanged."""
        return self.model(x)

    def freeze(self):
        """Disable gradients for the parameters under ``self.model.features``.

        Parameters belonging to a submodule named ``last_conv_1x1`` are left
        untouched, so they keep whatever ``requires_grad`` they already had.
        """
        # NOTE(review): assumes the backbone exposes a ``features`` submodule;
        # whether Resnet50 does is not visible here — confirm.
        for name, param in self.model.features.named_parameters():
            # named_parameters() yields dotted names like 'last_conv_1x1.weight',
            # so compare the leading module name. The former check
            # ``name != 'last_conv_1x1'`` was always true and exempted nothing.
            if name.split('.', 1)[0] != 'last_conv_1x1':
                param.requires_grad = False

    def unfreeze(self):
        """Re-enable gradients for every parameter under ``self.model.features``."""
        for param in self.model.features.parameters():
            param.requires_grad = True