Spaces:
Configuration error
Configuration error
| # Copyright 2020 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import importlib | |
def class_for_name(module_name, class_name):
    """Import ``module_name`` and return its attribute ``class_name``.

    Raises ImportError if the module cannot be loaded, and AttributeError
    if the module has no such attribute.
    """
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution with reflection padding.

    Padding equals the dilation so spatial size is preserved at stride 1.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
        padding_mode="reflect",
    )
    return nn.Conv2d(in_planes, out_planes, **options)
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    options = dict(
        kernel_size=1,
        stride=stride,
        bias=False,
        padding_mode="reflect",
    )
    return nn.Conv2d(in_planes, out_planes, **options)
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs plus a residual shortcut.

    The first conv (and the optional ``downsample`` branch) carries the
    stride; norm layers run without running statistics.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        # Both self.conv1 and self.downsample shrink the input when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes, track_running_stats=False, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes, track_running_stats=False, affine=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut is either the raw input or its projected/downsampled form.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    Follows the torchvision / ResNet V1.5 placement: the stride sits on the
    3x3 conv (self.conv2) rather than the first 1x1, which improves accuracy
    (https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch)
    relative to the original paper (https://arxiv.org/abs/1512.03385).
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample shrink the input when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm(width, track_running_stats=False, affine=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm(width, track_running_stats=False, affine=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm(
            planes * self.expansion, track_running_stats=False, affine=True
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
class conv(nn.Module):
    """Conv2d (reflect-padded, 'same' size at stride 1) -> BatchNorm -> ELU."""

    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            num_in_layers,
            num_out_layers,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            padding_mode="reflect",
        )
        # Batch statistics are computed per forward pass (no running stats).
        self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True)

    def forward(self, x):
        return F.elu(self.bn(self.conv(x)), inplace=True)
class upconv(nn.Module):
    """Bilinear upsampling by ``scale`` followed by a stride-1 conv block."""

    def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
        super(upconv, self).__init__()
        self.scale = scale
        self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)

    def forward(self, x):
        upsampled = nn.functional.interpolate(
            x, scale_factor=self.scale, mode="bilinear", align_corners=True
        )
        return self.conv(upsampled)
class ResUNet(nn.Module):
    """U-Net with a ResNet-style encoder producing a feature map.

    The encoder is three BasicBlock stages (strides 2/2/2 after a stride-2
    stem); the decoder upsamples twice with skip connections, so the output
    is at 1/4 of the input resolution with ``coarse_out_ch + fine_out_ch``
    channels (fine channels are dropped when ``coarse_only`` is set).
    """

    def __init__(
        self,
        encoder="resnet34",
        coarse_out_ch=32,
        fine_out_ch=32,
        norm_layer=None,
        coarse_only=False,
    ):
        super(ResUNet, self).__init__()
        valid_encoders = (
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
        )
        assert encoder in valid_encoders, "Incorrect encoder type"
        # Channel counts per stage differ between basic and bottleneck ResNets.
        if encoder in ["resnet18", "resnet34"]:
            filters = [64, 128, 256, 512]
        else:
            filters = [256, 512, 1024, 2048]

        self.coarse_only = coarse_only
        if self.coarse_only:
            fine_out_ch = 0
        self.coarse_out_ch = coarse_out_ch
        self.fine_out_ch = fine_out_ch
        out_ch = coarse_out_ch + fine_out_ch

        # Stage depths of the original resnet34 layout.
        layers = [3, 4, 6, 3]
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self.dilation = 1
        self.inplanes = 64
        self.groups = 1
        self.base_width = 64
        block = BasicBlock
        replace_stride_with_dilation = [False, False, False]

        # Stem: stride-2 7x7 conv, reflect-padded, no bias.
        self.conv1 = nn.Conv2d(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            padding_mode="reflect",
        )
        self.bn1 = self._norm_layer(
            self.inplanes, track_running_stats=False, affine=True
        )
        self.relu = nn.ReLU(inplace=True)

        # Encoder stages (note: layer1 is stride-2 here, unlike torchvision).
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )

        # Decoder: two upsample + skip + fuse steps.
        self.upconv3 = upconv(filters[2], 128, 3, 2)
        self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
        self.upconv2 = upconv(128, 64, 3, 2)
        self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)

        # Final 1x1 projection to the output channel count.
        self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack ``blocks`` residual blocks; the first one applies the stride."""
        norm_layer = self._norm_layer
        prev_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation to keep resolution.
            self.dilation *= stride
            stride = 1
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match spatial size / channel count.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(
                    planes * block.expansion, track_running_stats=False, affine=True
                ),
            )
        stages = [
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                prev_dilation,
                norm_layer,
            )
        ]
        self.inplanes = planes * block.expansion
        stages.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=norm_layer,
            )
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stages)

    def skipconnect(self, x1, x2):
        """Pad ``x1`` to ``x2``'s spatial size and concatenate on channels.

        For the padding rationale, see
        https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        """
        dy = x2.size()[2] - x1.size()[2]
        dx = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (dx // 2, dx - dx // 2, dy // 2, dy - dy // 2))
        return torch.cat([x2, x1], dim=1)

    def forward(self, x):
        # Encoder.
        stem = self.relu(self.bn1(self.conv1(x)))
        e1 = self.layer1(stem)
        e2 = self.layer2(e1)
        e3 = self.layer3(e2)
        # Decoder with skip connections from the matching encoder stages.
        d3 = self.iconv3(self.skipconnect(e2, self.upconv3(e3)))
        d2 = self.iconv2(self.skipconnect(e1, self.upconv2(d3)))
        # NOTE(review): the coarse/fine channel split of x_out was disabled
        # upstream; the full out_ch map is returned as-is.
        x_out = self.out_conv(d2)
        return x_out