# ImageColorizer / model.py
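"""Colorization networks: a U-Net style generator whose bottleneck is fused with
features from a pretrained timm ViT (it predicts the two ab chroma channels from a
single L luminance channel), plus a spectral-normalized PatchGAN discriminator
that scores concatenated (L, ab) pairs."""
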
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import json
from torch.nn.utils import spectral_norm
from torchinfo import summary


class EncoderBlock(nn.Module):
    """Two 3x3 conv + GroupNorm + LeakyReLU layers followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Return both the pooled output (input to the next stage) and the
        # pre-pooling features (used as the skip connection).
        features = self.conv_block(x)
        pooled = self.pool(features)
        return pooled, features


class DecoderBlock(nn.Module):
    """Upsample, attention-gate the skip connection, concatenate, then refine."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.ag = AttentionGate(F_g=in_channels // 2, F_l=skip_channels, F_int=in_channels // 4)
        conv_in_channels = in_channels // 2 + skip_channels
        self.conv_block = nn.Sequential(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        skip = self.ag(x, skip)  # gate the skip features with the upsampled decoder features
        x = torch.cat([x, skip], dim=1)
        x = self.conv_block(x)
        return x


class AttentionGate(nn.Module):
    """Additive attention gate: weights skip features x by a mask computed from
    the gating signal g (decoder features) and x itself."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(1, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # per-pixel attention mask in (0, 1)
        return x * psi


class ViTUNetColorizer(nn.Module):
    """U-Net generator whose bottleneck is fused with ViT patch features.

    Takes a 1-channel L (luminance) image and predicts the 2 ab chroma channels.
    """

    def __init__(self, vit_model_name="vit_tiny_patch16_224", freeze_vit_epochs=10):
        super(ViTUNetColorizer, self).__init__()
        self.vit = timm.create_model(vit_model_name, pretrained=True, num_classes=0)
        self.vit_embed_dim = self.vit.embed_dim
        self.vit.head = nn.Identity()  # classification head is not used
        # CNN encoder path; each block also returns its skip features.
        self.enc1 = EncoderBlock(1, 16)
        self.enc2 = EncoderBlock(16, 32)
        self.enc3 = EncoderBlock(32, 64)
        self.enc4 = EncoderBlock(64, 128)
        # Project the deepest CNN features onto the ViT's 14x14 token grid.
        self.bottleneck_processor = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.AdaptiveAvgPool2d((14, 14)),
        )
        # Fuse the CNN bottleneck with the ViT patch features.
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(128 + self.vit_embed_dim, 128, kernel_size=1),  # type: ignore
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
        )
        self.dec4 = DecoderBlock(128, 64, 64)
        self.dec3 = DecoderBlock(64, 32, 32)
        self.dec2 = DecoderBlock(32, 16, 16)
        self.final_conv = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.GroupNorm(8, 8),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(8, 2, kernel_size=1),
            nn.Tanh(),  # ab channels bounded to [-1, 1]
        )
        self.freeze_vit_epochs = freeze_vit_epochs
        self.current_epoch = 0

    def extract_vit_features(self, x):
        # Repeat the single L channel to 3 channels and resize to the ViT's
        # expected 224x224 input before running the patch embedding and blocks.
        B = x.shape[0]
        x_3ch = x.repeat(1, 3, 1, 1)
        if x_3ch.shape[-1] != 224:
            x_3ch = F.interpolate(
                x_3ch, size=(224, 224), mode="bicubic", align_corners=False
            )
        x_vit = self.vit.patch_embed(x_3ch)  # type: ignore
        if hasattr(self.vit, 'pos_embed') and self.vit.pos_embed is not None:
            # Skip the class-token position; only patch tokens are used here.
            x_vit = x_vit + self.vit.pos_embed[:, 1:, :]  # type: ignore
        x_vit = self.vit.pos_drop(x_vit)  # type: ignore
        for block in self.vit.blocks:  # type: ignore
            x_vit = block(x_vit)
        x_vit = self.vit.norm(x_vit)  # type: ignore
        # (B, 196, embed_dim) -> (B, embed_dim, 14, 14) spatial feature map.
        x_vit = x_vit.transpose(1, 2).reshape(B, self.vit_embed_dim, 14, 14)
        return x_vit

    def forward(self, x):
        # Encoder path; each stage returns (pooled, skip) features.
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        x4, skip4 = self.enc4(x3)  # skip4 is not consumed by the decoder
        # Fuse the CNN bottleneck with ViT features on the 14x14 grid, then
        # upsample back to the resolution of x3 before decoding.
        bottleneck = self.bottleneck_processor(x4)
        vit_features = self.extract_vit_features(x)
        fused = torch.cat([bottleneck, vit_features], dim=1)
        fused = self.fusion_layer(fused)
        fused = F.interpolate(fused, size=x3.shape[2:], mode="bilinear", align_corners=False)
        d4 = self.dec4(fused, skip3)
        d3 = self.dec3(d4, skip2)
        d2 = self.dec2(d3, skip1)
        out = self.final_conv(d2)
        return out

    def set_epoch(self, epoch):
        # Keep the ViT frozen for the first `freeze_vit_epochs` epochs, then fine-tune it.
        self.current_epoch = epoch
        requires_grad = epoch >= self.freeze_vit_epochs
        for param in self.vit.parameters():
            param.requires_grad = requires_grad

    def get_param_groups(self, lr_decoder=1e-4, lr_vit=1e-5):
        # Separate learning rates: a smaller one for the pretrained ViT, a larger
        # one for everything else (encoder, decoder, and fusion layers).
        vit_params = []
        decoder_params = []
        for name, param in self.named_parameters():
            if "vit" in name:
                vit_params.append(param)
            else:
                decoder_params.append(param)
        return [
            {"params": decoder_params, "lr": lr_decoder},
            {"params": vit_params, "lr": lr_vit},
        ]
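
# Illustrative usage sketch (hypothetical training loop; `num_epochs` is a placeholder):
#   generator = ViTUNetColorizer()
#   optimizer_G = torch.optim.Adam(generator.get_param_groups(lr_decoder=1e-4, lr_vit=1e-5))
#   for epoch in range(num_epochs):
#       generator.set_epoch(epoch)  # ViT parameters stay frozen until `freeze_vit_epochs`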


class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator over concatenated (L, ab) images, with spectral norm."""

    def __init__(self, in_channels=3, n_filters=64):
        super(PatchDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride=2):
            return [
                spectral_norm(
                    nn.Conv2d(
                        in_filters, out_filters, kernel_size=4, stride=stride, padding=1
                    )
                ),
                nn.LeakyReLU(0.01, inplace=True)
            ]

        self.model = nn.Sequential(
            *discriminator_block(in_channels, n_filters),
            *discriminator_block(n_filters, n_filters * 2),
            *discriminator_block(n_filters * 2, n_filters * 4),
            spectral_norm(nn.Conv2d(n_filters * 4, 1, kernel_size=4, padding=1))
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, L, ab):
        # Score the concatenated luminance and chroma channels per patch.
        img_input = torch.cat((L, ab), dim=1)
        return self.model(img_input)


if __name__ == "__main__":
    # Print architecture summaries; fall back to 256x256 if no config file is found.
    try:
        with open("hyperparameters.json", "r") as f:
            hparams = json.load(f)
        resolution = hparams.get("resolution", 256)
    except FileNotFoundError:
        resolution = 256
        print("Using default resolution: 256x256")

    generator = ViTUNetColorizer()
    generator_input_size = (1, 1, resolution, resolution)
    summary(generator, input_size=generator_input_size)

    discriminator = PatchDiscriminator()
    discriminator_input_size = [(1, 1, resolution, resolution), (1, 2, resolution, resolution)]
    summary(discriminator, input_size=discriminator_input_size)
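
    # Optional smoke test (an illustrative sketch): push a random L-channel tensor
    # through the generator and feed the resulting (L, ab) pair to the discriminator
    # to confirm the tensor shapes line up end to end.
    device = next(generator.parameters()).device
    dummy_L = torch.randn(1, 1, resolution, resolution, device=device)
    discriminator = discriminator.to(device)
    with torch.no_grad():
        fake_ab = generator(dummy_L)                    # (1, 2, H, W), Tanh-bounded ab
        patch_logits = discriminator(dummy_L, fake_ab)  # PatchGAN real/fake score map
    print("Generator output shape:", tuple(fake_ab.shape))
    print("Discriminator output shape:", tuple(patch_logits.shape))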