|
""" |
|
UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet |
|
with architecture referenced from https://keras.io/examples/vision/depth_estimation |
|
for monocular depth estimation from RGB images, i.e. one output channel. |
|
""" |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
class UNet(nn.Module): |
|
""" |
|
The overall UNet architecture. |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.downscale_blocks = nn.ModuleList( |
|
[ |
|
DownBlock(16, 32), |
|
DownBlock(32, 64), |
|
DownBlock(64, 128), |
|
DownBlock(128, 256), |
|
] |
|
) |
|
self.upscale_blocks = nn.ModuleList( |
|
[ |
|
UpBlock(256, 128), |
|
UpBlock(128, 64), |
|
UpBlock(64, 32), |
|
UpBlock(32, 16), |
|
] |
|
) |
|
|
|
self.input_conv = nn.Conv2d(3, 16, kernel_size=3, padding="same") |
|
self.output_conv = nn.Conv2d(16, 1, kernel_size=1) |
|
self.bridge = BottleNeckBlock(256) |
|
self.activation = nn.Sigmoid() |
|
|
|
def forward(self, x): |
|
x = self.input_conv(x) |
|
|
|
skip_features = [] |
|
for block in self.downscale_blocks: |
|
c, x = block(x) |
|
skip_features.append(c) |
|
|
|
x = self.bridge(x) |
|
|
|
skip_features.reverse() |
|
for block, skip in zip(self.upscale_blocks, skip_features): |
|
x = block(x, skip) |
|
|
|
x = self.output_conv(x) |
|
x = self.activation(x) |
|
return x |
|
|
|
|
|
class DownBlock(nn.Module): |
|
""" |
|
Module that performs downscaling with residual connections. |
|
""" |
|
|
|
def __init__(self, in_channels, out_channels, padding="same", stride=1): |
|
super().__init__() |
|
self.conv1 = nn.Conv2d( |
|
in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=stride, |
|
padding=padding, |
|
bias=False, |
|
) |
|
self.conv2 = nn.Conv2d( |
|
out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=stride, |
|
padding=padding, |
|
bias=False, |
|
) |
|
self.bn1 = nn.BatchNorm2d(out_channels) |
|
self.bn2 = nn.BatchNorm2d(out_channels) |
|
self.relu = nn.LeakyReLU(0.2) |
|
self.maxpool = nn.MaxPool2d(2) |
|
|
|
def forward(self, x): |
|
d = self.conv1(x) |
|
x = self.bn1(d) |
|
x = self.relu(x) |
|
|
|
x = self.conv2(x) |
|
x = self.bn2(x) |
|
x = self.relu(x) |
|
|
|
x = x + d |
|
p = self.maxpool(x) |
|
return x, p |
|
|
|
|
|
class UpBlock(nn.Module): |
|
""" |
|
Module that performs upscaling after concatenation with skip connections. |
|
""" |
|
|
|
def __init__(self, in_channels, out_channels, padding="same", stride=1): |
|
super().__init__() |
|
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) |
|
self.conv1 = nn.Conv2d( |
|
in_channels * 2, |
|
in_channels, |
|
kernel_size=3, |
|
stride=stride, |
|
padding=padding, |
|
bias=False, |
|
) |
|
self.conv2 = nn.Conv2d( |
|
in_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=stride, |
|
padding=padding, |
|
bias=False, |
|
) |
|
self.bn1 = nn.BatchNorm2d(in_channels) |
|
self.bn2 = nn.BatchNorm2d(out_channels) |
|
self.relu = nn.LeakyReLU(0.2) |
|
|
|
def forward(self, x, skip): |
|
x = self.up(x) |
|
x = torch.cat([x, skip], dim=1) |
|
|
|
x = self.conv1(x) |
|
x = self.bn1(x) |
|
x = self.relu(x) |
|
|
|
x = self.conv2(x) |
|
x = self.bn2(x) |
|
x = self.relu(x) |
|
return x |
|
|
|
|
|
class BottleNeckBlock(nn.Module): |
|
""" |
|
BottleNeckBlock that serves as the UNet bridge. |
|
""" |
|
|
|
def __init__(self, channels, padding="same", strides=1): |
|
super().__init__() |
|
self.conv1 = nn.Conv2d(channels, channels, 3, 1, "same") |
|
self.conv2 = nn.Conv2d(channels, channels, 3, 1, "same") |
|
self.relu = nn.LeakyReLU(0.2) |
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
x = self.relu(x) |
|
x = self.conv2(x) |
|
x = self.relu(x) |
|
return x |