import torch
import torch.nn as nn
import torchinfo

class convLayer(nn.Module):
    """3x3 convolution, optionally followed by max pooling, then batch
    normalization, ReLU, and optional dropout."""

    def __init__(self, l_input_c, l_output_c, bias=False,
                 padding=1, stride=1, max_pooling=False, dropout=0):
        super().__init__()
        self.convLayer = nn.Conv2d(in_channels=l_input_c,
                                   out_channels=l_output_c,
                                   kernel_size=(3, 3),
                                   stride=stride,
                                   padding=padding,
                                   padding_mode='replicate',
                                   bias=bias)
        self.max_pooling = None
        if max_pooling:
            self.max_pooling = nn.MaxPool2d(2, 2)
        self.normLayer = nn.BatchNorm2d(l_output_c)
        self.activationLayer = nn.ReLU()
        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.convLayer(x)
        if self.max_pooling is not None:
            x = self.max_pooling(x)
        x = self.normLayer(x)
        x = self.activationLayer(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x
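
# A minimal shape sketch for convLayer (illustrative only, not part of the
# model): with padding=1 and stride=1 the 3x3 convolution preserves spatial
# size, and the optional max pooling halves it, e.g.
#
#   layer = convLayer(l_input_c=3, l_output_c=64, max_pooling=True)
#   layer(torch.randn(1, 3, 32, 32)).shape  # -> (1, 64, 16, 16)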

class custBlock(nn.Module):
    """A convLayer (optionally with max pooling) followed by an optional
    residual branch of `residual_links` convLayers. The branch output is
    added back to the branch input, forming a skip connection."""

    def __init__(self, l_input_c, l_output_c, bias=False,
                 padding=1, stride=1, max_pooling=True,
                 dropout=0, residual_links=2):
        super().__init__()
        self.conv_pool_block = convLayer(l_input_c=l_input_c,
                                         l_output_c=l_output_c,
                                         bias=bias, padding=padding,
                                         stride=stride,
                                         max_pooling=max_pooling,
                                         dropout=dropout)
        self.residual_block = None
        if residual_links > 0:
            res_layer_seq = [
                convLayer(l_input_c=l_output_c, l_output_c=l_output_c,
                          bias=bias, padding=padding, stride=stride,
                          max_pooling=False, dropout=dropout)
                for _ in range(residual_links)
            ]
            self.residual_block = nn.Sequential(*res_layer_seq)

    def forward(self, x):
        x = self.conv_pool_block(x)
        if self.residual_block is not None:
            tmp_x = x
            x = self.residual_block(x)
            x = x + tmp_x
        return x
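
# Sketch of the skip connection (illustrative only): the residual branch
# keeps channels and spatial size fixed, so the elementwise addition in
# forward() is always shape-compatible, e.g.
#
#   block = custBlock(l_input_c=64, l_output_c=128,
#                     max_pooling=True, residual_links=2)
#   block(torch.randn(1, 64, 32, 32)).shape  # -> (1, 128, 16, 16)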

class custResNet(nn.Module):
    """Custom ResNet for 32x32 inputs (e.g. CIFAR-10): a prep block, three
    convolution blocks (two with skip connections), max pooling, and a
    final fully connected layer."""

    def __init__(self, dropout=0):
        super().__init__()

        ##### Prep Block #####
        # 3x3 convolution (stride=1, padding=1) with 64 kernels, followed
        # by batch normalization and ReLU. No pooling and no residual
        # branch; dropout defaults to 0.
        self.prep_block = custBlock(l_input_c=3, l_output_c=64,
                                    max_pooling=False, dropout=dropout,
                                    residual_links=0
                                    )  # output_size = 32, rf_out = 3
        ##### Convolution Block - 1 #####
        # Step 1: 3x3 convolution (stride=1, padding=1) with 128 kernels,
        # followed by max pooling, batch normalization and ReLU.
        # Step 2: a residual branch of two convLayers, each a 3x3
        # convolution (stride=1, padding=1) followed by batch
        # normalization and ReLU.
        # Step 3: the results of steps 1 and 2 are added, forming a skip
        # connection. Dropout defaults to 0.
        self.block1 = custBlock(l_input_c=64, l_output_c=128,
                                max_pooling=True, dropout=dropout,
                                residual_links=2
                                )  # output_size = 16, rf_out = 13
        ##### Convolution Block - 2 #####
        # 3x3 convolution (stride=1, padding=1) with 256 kernels, followed
        # by max pooling, batch normalization and ReLU. No residual
        # branch; dropout defaults to 0.
        self.block2 = custBlock(l_input_c=128, l_output_c=256,
                                max_pooling=True, dropout=dropout,
                                residual_links=0
                                )  # output_size = 8, rf_out = 17
        ##### Convolution Block - 3 #####
        # Same structure as Block 1 but with 512 kernels: a conv + max
        # pooling + batch norm + ReLU step, a two-convLayer residual
        # branch, and a skip connection adding the two results.
        # Dropout defaults to 0.
        self.block3 = custBlock(l_input_c=256, l_output_c=512,
                                max_pooling=True, dropout=dropout,
                                residual_links=2
                                )  # output_size = 4, rf_out = 57
        self.max_pool_layer = nn.MaxPool2d(4, 4)
        # output_size = 1, rf_out = 81
        self.flatten_layer = nn.Flatten()
        self.fc = nn.Linear(512, 10)
        # No softmax layer: the network returns raw logits, which is what
        # nn.CrossEntropyLoss expects during training.

    def forward(self, x):
        x = self.prep_block(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.max_pool_layer(x)
        x = self.flatten_layer(x)
        x = self.fc(x)
        return x

    # Network summary via torchinfo
    def summary(self, input_size=None, depth=10):
        return torchinfo.summary(self, input_size=input_size,
                                 depth=depth,
                                 col_names=["input_size",
                                            "output_size",
                                            "num_params",
                                            "kernel_size",
                                            "params_percent"])