# Nature-Nexus/utils/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionGate(nn.Module):
    """Additive attention gate that re-weights encoder skip features (x)
    using the decoder gating signal (g) before concatenation."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # Project the gating signal to the intermediate channel size
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Project the skip-connection features to the intermediate channel size
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapse to a single-channel attention map in [0, 1]
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
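

# Illustrative sketch (not part of the original file): a quick shape check for
# AttentionGate. The channel and spatial sizes below are arbitrary assumptions;
# the point is that the gate re-weights the skip features per pixel without
# changing their shape.
def _attention_gate_shape_check():
    gate = AttentionGate(F_g=64, F_l=64, F_int=32)
    g = torch.randn(1, 64, 64, 64)  # decoder gating signal
    x = torch.randn(1, 64, 64, 64)  # encoder skip-connection features
    out = gate(g, x)
    assert out.shape == x.shape  # attention only rescales x, shape is preserved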
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU; the spatial
    size is preserved (padding=1) while the channel count changes."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
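

# Illustrative sketch (not part of the original file): DoubleConv changes the
# channel count but keeps spatial resolution, which the encoder, decoder and
# bottleneck below rely on. The sizes used here are arbitrary assumptions.
def _double_conv_shape_check():
    block = DoubleConv(in_channels=3, out_channels=64)
    x = torch.randn(1, 3, 128, 128)
    assert block(x).shape == (1, 64, 128, 128)  # channels change, resolution does not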
class AttentionUNet(nn.Module):
    """U-Net with attention gates applied to the skip connections."""

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(AttentionUNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.attention_gates = nn.ModuleList()

        # Down part of U-Net
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of U-Net
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2,
                    feature,
                    kernel_size=2,
                    stride=2,
                )
            )
            # Attention gate for this decoder stage
            self.attention_gates.append(
                AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder path
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # reverse so index 0 is the deepest skip

        # Decoder path: self.ups alternates [ConvTranspose2d, DoubleConv, ...]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # If sizes don't match (e.g., odd input dimensions), resize to the skip connection
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            # Apply attention gate to the skip connection, gated by the decoder features
            skip_connection = self.attention_gates[idx // 2](g=x, x=skip_connection)

            # Concatenate along the channel dimension
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        # Final 1x1 conv produces per-pixel logits
        return self.final_conv(x)
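

# Illustrative sketch (not part of the original file): a forward-pass smoke
# test for the full network. The 3x256x256 input and batch size 1 are
# assumptions; inputs whose sides are not powers of two are handled by the
# F.interpolate fallback in forward().
def _attention_unet_smoke_test():
    model = AttentionUNet(in_channels=3, out_channels=1)
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        logits = model(x)
    # One output channel of raw logits at the input resolution
    assert logits.shape == (1, 1, 256, 256)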
def load_model(model_path):
    """
    Load the trained model.

    Args:
        model_path: Path to the model weights

    Returns:
        Loaded model
    """
    # Define device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = AttentionUNet(in_channels=3, out_channels=1)

    # Load model weights
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Set model to evaluation mode
    model.eval()

    return model.to(device)
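

# Illustrative usage sketch (not part of the original file): running inference
# with the loaded model. The weights path, input size and the 0.5 threshold
# are placeholder assumptions; the model returns raw logits, so a sigmoid is
# applied before thresholding the single-channel mask.
def _inference_example():
    model = load_model("weights/attention_unet.pth")  # hypothetical path
    device = next(model.parameters()).device
    image = torch.randn(1, 3, 256, 256, device=device)  # stand-in for a real RGB image
    with torch.no_grad():
        logits = model(image)
        mask = (torch.sigmoid(logits) > 0.5).float()  # binary segmentation mask
    return mask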