| """ | |
| Author: Soubhik Sanyal | |
| Copyright (c) 2019, Soubhik Sanyal | |
| All rights reserved. | |
| Loads different resnet models | |
| """ | |
'''
file: Resnet.py
date: 2018_05_02
author: zhangxiong([email protected])
mark: copied from pytorch source code
'''

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn.parameter import Parameter


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # The classification head is dropped: the network is used as a
        # feature extractor, so num_classes is effectively unused.
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He initialization for convolutions; unit weight / zero bias for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the identity branch when the spatial size or channel
            # count changes, so the residual addition is well defined.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x1 = self.layer4(x)
        x2 = self.avgpool(x1)
        x2 = x2.view(x2.size(0), -1)
        # x = self.fc(x)
        # x2: [bz, 2048] for shape
        # x1: [bz, 2048, 7, 7] for texture
        return x2


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


def copy_parameter_from_resnet(model, resnet_dict):
    """Copy matching entries from a torchvision state dict into `model`,
    silently skipping keys that are missing or shape-inconsistent."""
    cur_state_dict = model.state_dict()
    for name, param in resnet_dict.items():
        if name not in cur_state_dict:
            # e.g. fc.weight / fc.bias: the classification head was removed.
            continue
        if isinstance(param, Parameter):
            param = param.data
        try:
            cur_state_dict[name].copy_(param)
        except Exception:
            # print(name, ' is inconsistent!')
            continue


# Note: recent torchvision releases prefer the weights= argument over
# pretrained=True, which is kept here for compatibility with the original.
def load_ResNet50Model():
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    copy_parameter_from_resnet(
        model,
        torchvision.models.resnet50(pretrained=True).state_dict())
    return model


def load_ResNet101Model():
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    copy_parameter_from_resnet(
        model,
        torchvision.models.resnet101(pretrained=True).state_dict())
    return model


def load_ResNet152Model():
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    copy_parameter_from_resnet(
        model,
        torchvision.models.resnet152(pretrained=True).state_dict())
    return model
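

# A minimal usage sketch (not part of the original file): the encoder returns
# pooled 2048-d features, assuming a 224x224 input so that avgpool(7) reduces
# the final 7x7 map to 1x1.
if __name__ == '__main__':
    encoder = load_ResNet50Model().eval()
    with torch.no_grad():
        feats = encoder(torch.randn(2, 3, 224, 224))
    print(feats.shape)  # torch.Size([2, 2048])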


# UNet
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # If bilinear, use the normal convolutions to reduce the number of
        # channels; otherwise upsample with a learned transposed convolution.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2,
                                  mode='bilinear',
                                  align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                         kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is CHW; pad x1 so its spatial size matches the skip
        # connection x2 (the sizes can differ when an input dimension is odd).
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
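

# Worked example for the F.pad alignment in Up.forward above (illustrative,
# not from the original source): a 27-pixel-wide skip map is pooled to
# floor(27 / 2) = 13 on the way down, and upsampling doubles it back to only
# 26; diffX = 27 - 26 = 1, so F.pad adds 0 pixels on the left and 1 on the
# right to realign the two tensors before concatenation.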


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # Note: self.outc is defined but not applied here; the output is the
        # 64-channel feature map, L2-normalized along the channel dimension.
        x = F.normalize(x)
        return x
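

# Quick shape check (illustrative sketch, not from the original source): with
# bilinear upsampling the output keeps the input resolution, and because the
# forward pass skips self.outc it carries 64 channels rather than n_classes.
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    out = net(torch.randn(1, 3, 64, 64))
    print(out.shape)  # torch.Size([1, 64, 64, 64])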