Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.nn.init as init | |
| from .box_utils import Detect, PriorBox | |
class L2Norm(nn.Module):
    """L2-normalise across dimension 1, then apply a learnable per-channel scale.

    At every position the vector along dim 1 is divided by its L2 norm
    (plus ``eps``), and the result is multiplied element-wise by ``weight``.
    ``weight`` is a length-``n_channels`` parameter filled with ``scale``.
    ``forward`` expands ``weight`` to a 4-D ``(1, n_channels, 1, 1)`` view,
    so the input must be 4-D with ``n_channels`` entries in dim 1.

    Args:
        n_channels: Size of the input's dimension 1 and length of ``weight``.
        scale: Initial value of every ``weight`` entry. A falsy value is
            stored as ``None`` in ``gamma``.
    """

    def __init__(self, n_channels, scale):
        super().__init__()
        self.n_channels = n_channels
        self.gamma = scale if scale else None
        self.eps = 1e-10  # keeps the division finite for all-zero vectors
        # Allocated uninitialised; reset_parameters() fills it right away.
        self.weight = nn.Parameter(torch.empty(n_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every entry of ``weight`` to ``gamma``."""
        init.constant_(self.weight, self.gamma)

    def forward(self, x):
        """Return ``weight * x / (sqrt(sum(x**2, dim=1)) + eps)``."""
        denom = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        normalized = x / denom
        per_channel = self.weight.view(1, -1, 1, 1).expand_as(normalized)
        return per_channel * normalized
class S3FDNet(nn.Module):
    """S3FD-style single-shot detector: VGG16 backbone plus multibox heads.

    Six feature maps ("sources") are collected:

    * the L2-normalised outputs of vgg up to relu3_3, relu4_3 and relu5_3;
    * the raw fc7 output at the end of ``vgg``;
    * the outputs of the two stride-2 ``extras`` stages.

    Each source gets two heads. ``loc`` predicts box regression with 4
    channels. ``conf`` predicts classification with 2 channels; the first
    ``conf`` head emits 4 channels, which ``forward`` reduces to 2.
    ``forward`` then builds prior boxes for the input size and hands
    everything to ``Detect``.

    NOTE: attribute names and the order of layers inside ``vgg``,
    ``extras``, ``loc`` and ``conf`` define the ``state_dict`` keys.
    Changing them breaks loading of existing checkpoints.

    Args:
        device: Stored as ``self.device`` for callers that read it. Prior
            boxes are placed on the device of the network's predictions,
            not on this value. A model moved with ``.to()``/``.cpu()``
            therefore still works when this was left at its ``'cuda'``
            default.
    """

    # Exclusive end indices of the vgg segments whose outputs are tapped and
    # L2-normalised:
    #   [0, 16)  ends at relu3_3
    #   [16, 23) ends at relu4_3
    #   [23, 30) ends at relu5_3
    # The remaining layers [30, len(vgg)) end at relu7 (fc7).
    _VGG_TAPS = (16, 23, 30)

    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device

        self.vgg = nn.ModuleList([
            # conv1_x + pool1 (indices 0-4)
            nn.Conv2d(3, 64, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # conv2_x + pool2 (indices 5-9)
            nn.Conv2d(64, 128, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # conv3_x (indices 10-15) + pool3 (index 16, ceil_mode)
            nn.Conv2d(128, 256, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            # conv4_x (indices 17-22) + pool4 (index 23)
            nn.Conv2d(256, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # conv5_x (indices 24-29) + pool5 (index 30)
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # fc6 as a dilated 3x3 conv, fc7 as a 1x1 conv (indices 31-34)
            nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 1, 1),
            nn.ReLU(inplace=True),
        ])

        self.L2Norm3_3 = L2Norm(256, 10)
        self.L2Norm4_3 = L2Norm(512, 8)
        self.L2Norm5_3 = L2Norm(512, 5)

        # Two (1x1 reduce, stride-2 3x3) pairs.
        self.extras = nn.ModuleList([
            nn.Conv2d(1024, 256, 1, 1),
            nn.Conv2d(256, 512, 3, 2, padding=1),
            nn.Conv2d(512, 128, 1, 1),
            nn.Conv2d(128, 256, 3, 2, padding=1),
        ])

        # One box-regression head per source (input channels match the source).
        self.loc = nn.ModuleList([
            nn.Conv2d(256, 4, 3, 1, padding=1),
            nn.Conv2d(512, 4, 3, 1, padding=1),
            nn.Conv2d(512, 4, 3, 1, padding=1),
            nn.Conv2d(1024, 4, 3, 1, padding=1),
            nn.Conv2d(512, 4, 3, 1, padding=1),
            nn.Conv2d(256, 4, 3, 1, padding=1),
        ])

        # One classification head per source. The first has 4 outputs that
        # forward() collapses to 2; the others have 2.
        self.conf = nn.ModuleList([
            nn.Conv2d(256, 4, 3, 1, padding=1),
            nn.Conv2d(512, 2, 3, 1, padding=1),
            nn.Conv2d(512, 2, 3, 1, padding=1),
            nn.Conv2d(1024, 2, 3, 1, padding=1),
            nn.Conv2d(512, 2, 3, 1, padding=1),
            nn.Conv2d(256, 2, 3, 1, padding=1),
        ])

        self.softmax = nn.Softmax(dim=-1)
        self.detect = Detect()

    def _collect_sources(self, x):
        """Run ``vgg`` and ``extras`` on ``x``; return the six source maps in order."""
        sources = []
        norms = (self.L2Norm3_3, self.L2Norm4_3, self.L2Norm5_3)
        start = 0
        for end, norm in zip(self._VGG_TAPS, norms):
            for k in range(start, end):
                x = self.vgg[k](x)
            # The normalised copy becomes the source; the un-normalised x
            # keeps flowing through the backbone.
            sources.append(norm(x))
            start = end
        for k in range(start, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)  # fc7 output, not normalised

        # Every second extras layer closes a pair; its output is a source.
        for k, layer in enumerate(self.extras):
            x = F.relu(layer(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)
        return sources

    def _apply_heads(self, sources):
        """Apply the loc/conf heads to each source.

        Returns:
            ``(loc_maps, conf_maps)``: lists of channels-last
            ``(N, H, W, C)`` tensors, one entry per source.
        """
        loc_maps = []
        conf_maps = []
        for i, src in enumerate(sources):
            loc_x = self.loc[i](src)
            conf_x = self.conf[i](src)
            if i == 0:
                # First conf head has 4 channels. Keep the max over channels
                # 0-2 as a single score and append channel 3, giving 2
                # channels like the other heads.
                max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True)
                conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1)
            loc_maps.append(loc_x.permute(0, 2, 3, 1).contiguous())
            conf_maps.append(conf_x.permute(0, 2, 3, 1).contiguous())
        return loc_maps, conf_maps

    def forward(self, x):
        """Run the detector on a batch.

        Args:
            x: 4-D input ``(N, 3, H, W)``. ``x.size()[2:]`` is passed to
                ``PriorBox`` as the input size.

        Returns:
            Whatever ``Detect.forward`` returns when given:

            * box offsets shaped ``(N, num_priors, 4)``;
            * softmaxed 2-class scores shaped ``(N, num_priors, 2)``;
            * the prior boxes.

        Side effects:
            Sets ``self.priorbox`` and ``self.priors`` on every call.
        """
        size = x.size()[2:]
        sources = self._collect_sources(x)
        loc_maps, conf_maps = self._apply_heads(sources)

        # (height, width) of each head output, in source order, for PriorBox.
        features_maps = [[o.size(1), o.size(2)] for o in loc_maps]

        loc = torch.cat([o.view(o.size(0), -1) for o in loc_maps], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf_maps], 1)

        with torch.no_grad():
            self.priorbox = PriorBox(size, features_maps)
            self.priors = self.priorbox.forward()

        # Cast exactly as before, but place the priors on the predictions'
        # device instead of the constructor-time ``self.device``. The latter
        # defaults to 'cuda' and mismatches once the model/input live
        # elsewhere. Presumably Detect combines loc with the priors, so they
        # must share a device.
        priors = self.priors.type(type(sources[-1].data)).to(loc.device)
        return self.detect.forward(
            loc.view(loc.size(0), -1, 4),
            self.softmax(conf.view(conf.size(0), -1, 2)),
            priors,
        )