"""ViT-augmented attention U-Net colorizer with a PatchGAN discriminator.

The generator (``ViTUNetColorizer``) consumes a single-channel L (lightness)
image and predicts the two ab chroma channels, squashed to [-1, 1] by a final
Tanh. A pretrained ViT backbone supplies global semantic context that is
fused into the U-Net bottleneck; it can be kept frozen for the first
``freeze_vit_epochs`` epochs via ``set_epoch``. The discriminator
(``PatchDiscriminator``) is a spectral-normalized PatchGAN over the
concatenated (L, ab) image.
"""

import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torchinfo import summary


class EncoderBlock(nn.Module):
    """Double 3x3 conv (GroupNorm + LeakyReLU) followed by a 2x2 max-pool.

    ``forward`` returns ``(pooled, features)``: the pooled tensor feeds the
    next encoder stage, while the pre-pool ``features`` serve as the skip
    connection for the matching decoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv_block(x)
        pooled = self.pool(features)
        return pooled, features


class AttentionGate(nn.Module):
    """Additive attention gate over a skip connection.

    Projects the gating signal ``g`` (decoder feature) and the skip feature
    ``x`` to ``F_int`` channels, combines them, and produces a per-pixel
    sigmoid mask used to re-weight the skip feature.

    Defined before ``DecoderBlock``, which instantiates it.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(8, F_int),
        )
        # 1-channel attention map; GroupNorm(1, 1) acts as a LayerNorm over it.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(1, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class DecoderBlock(nn.Module):
    """Transposed-conv upsample, attention-gated skip fusion, double conv.

    The gate re-weights ``skip`` using the upsampled decoder feature as the
    gating signal before the two are concatenated channel-wise.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.ag = AttentionGate(
            F_g=in_channels // 2, F_l=skip_channels, F_int=in_channels // 4
        )
        conv_in_channels = in_channels // 2 + skip_channels
        self.conv_block = nn.Sequential(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)
        skip = self.ag(x, skip)
        x = torch.cat([x, skip], dim=1)
        x = self.conv_block(x)
        return x


class ViTUNetColorizer(nn.Module):
    """U-Net colorizer whose bottleneck is fused with ViT patch features.

    Args:
        vit_model_name: timm model name for the backbone (must expose
            ``patch_embed``, ``pos_drop``, ``blocks`` and ``norm``).
        freeze_vit_epochs: number of initial epochs during which
            ``set_epoch`` keeps the ViT parameters frozen.
    """

    def __init__(self, vit_model_name="vit_tiny_patch16_224", freeze_vit_epochs=10):
        super(ViTUNetColorizer, self).__init__()
        self.vit = timm.create_model(vit_model_name, pretrained=True, num_classes=0)
        self.vit_embed_dim = self.vit.embed_dim
        self.vit.head = nn.Identity()

        # Encoder: 1 -> 16 -> 32 -> 64 -> 128 channels, halving resolution
        # at each stage.
        self.enc1 = EncoderBlock(1, 16)
        self.enc2 = EncoderBlock(16, 32)
        self.enc3 = EncoderBlock(32, 64)
        self.enc4 = EncoderBlock(64, 128)

        # Pool the deepest CNN feature to the ViT token grid (14x14 for a
        # 224-px / patch-16 backbone) so the two can be concatenated.
        self.bottleneck_processor = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.AdaptiveAvgPool2d((14, 14)),
        )
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(128 + self.vit_embed_dim, 128, kernel_size=1),  # type: ignore
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
        )

        # Decoder mirrors the encoder; the deepest skip (from enc4) is not
        # consumed — the fused ViT bottleneck takes its place.
        self.dec4 = DecoderBlock(128, 64, 64)
        self.dec3 = DecoderBlock(64, 32, 32)
        self.dec2 = DecoderBlock(32, 16, 16)

        self.final_conv = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.GroupNorm(8, 8),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(8, 2, kernel_size=1),
            nn.Tanh(),  # ab channels in [-1, 1]
        )

        self.freeze_vit_epochs = freeze_vit_epochs
        self.current_epoch = 0

    def extract_vit_features(self, x):
        """Run the ViT trunk on the grayscale input.

        Returns a (B, embed_dim, 14, 14) feature map. The single L channel is
        tiled to 3 channels and resized to 224x224 to match the pretrained
        backbone's expected input.
        """
        B = x.shape[0]
        x_3ch = x.repeat(1, 3, 1, 1)
        if x_3ch.shape[-1] != 224:
            x_3ch = F.interpolate(
                x_3ch, size=(224, 224), mode="bicubic", align_corners=False
            )
        x_vit = self.vit.patch_embed(x_3ch)  # type: ignore
        if hasattr(self.vit, 'pos_embed') and self.vit.pos_embed is not None:
            # NOTE(review): the [:, 1:, :] slice assumes pos_embed carries a
            # leading class-token entry (true for vit_*_patch16_224); models
            # without a class token would need the full embedding — confirm
            # if the backbone name is changed.
            x_vit = x_vit + self.vit.pos_embed[:, 1:, :]  # type: ignore
        x_vit = self.vit.pos_drop(x_vit)  # type: ignore
        for block in self.vit.blocks:  # type: ignore
            x_vit = block(x_vit)
        x_vit = self.vit.norm(x_vit)  # type: ignore
        # 196 tokens -> 14x14 spatial grid (assumes a patch-16 / 224-px ViT).
        x_vit = x_vit.transpose(1, 2).reshape(B, self.vit_embed_dim, 14, 14)
        return x_vit

    def forward(self, x):
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        # Deepest skip is intentionally discarded; see __init__.
        x4, _ = self.enc4(x3)

        bottleneck = self.bottleneck_processor(x4)
        vit_features = self.extract_vit_features(x)
        fused = torch.cat([bottleneck, vit_features], dim=1)
        fused = self.fusion_layer(fused)
        # Bring the 14x14 fused map back to the resolution dec4 expects
        # (same spatial size as x3, i.e. input/8).
        fused = F.interpolate(
            fused, size=x3.shape[2:], mode="bilinear", align_corners=False
        )

        d4 = self.dec4(fused, skip3)
        d3 = self.dec3(d4, skip2)
        d2 = self.dec2(d3, skip1)
        out = self.final_conv(d2)
        return out

    def set_epoch(self, epoch):
        """Record the epoch and (un)freeze the ViT accordingly.

        The ViT is frozen while ``epoch < freeze_vit_epochs``. Callers that
        build the optimizer before unfreezing should still include the ViT
        params (see ``get_param_groups``) so gradients flow once enabled.
        """
        self.current_epoch = epoch
        requires_grad = epoch >= self.freeze_vit_epochs
        for param in self.vit.parameters():
            param.requires_grad = requires_grad

    def get_param_groups(self, lr_decoder=1e-4, lr_vit=1e-5):
        """Split parameters into decoder/ViT groups with separate LRs."""
        vit_params = []
        decoder_params = []
        for name, param in self.named_parameters():
            if "vit" in name:
                vit_params.append(param)
            else:
                decoder_params.append(param)
        return [
            {"params": decoder_params, "lr": lr_decoder},
            {"params": vit_params, "lr": lr_vit},
        ]


class PatchDiscriminator(nn.Module):
    """Spectral-normalized PatchGAN discriminator over (L, ab) inputs.

    Args:
        in_channels: channels of the concatenated input (1 for L + 2 for ab).
        n_filters: base filter count, doubled at each downsampling stage.
    """

    def __init__(self, in_channels=3, n_filters=64):
        super(PatchDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride=2):
            # 4x4 stride-2 conv halves the resolution; spectral norm
            # stabilizes adversarial training.
            return [
                spectral_norm(
                    nn.Conv2d(
                        in_filters, out_filters, kernel_size=4,
                        stride=stride, padding=1,
                    )
                ),
                nn.LeakyReLU(0.01, inplace=True),
            ]

        self.model = nn.Sequential(
            *discriminator_block(in_channels, n_filters),
            *discriminator_block(n_filters, n_filters * 2),
            *discriminator_block(n_filters * 2, n_filters * 4),
            spectral_norm(nn.Conv2d(n_filters * 4, 1, kernel_size=4, padding=1)),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # DCGAN-style init: N(0, 0.02) weights, zero bias.
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, L, ab):
        img_input = torch.cat((L, ab), dim=1)
        return self.model(img_input)


if __name__ == "__main__":
    # Smoke test: print torchinfo summaries for both networks at the
    # configured resolution (falls back to 256 if the config is absent).
    try:
        with open("hyperparameters.json", "r") as f:
            hparams = json.load(f)
        resolution = hparams.get("resolution", 256)
    except FileNotFoundError:
        resolution = 256
        print("Using default resolution: 256x256")

    generator = ViTUNetColorizer()
    generator_input_size = (1, 1, resolution, resolution)
    summary(generator, input_size=generator_input_size)

    discriminator = PatchDiscriminator()
    discriminator_input_size = [
        (1, 1, resolution, resolution),
        (1, 2, resolution, resolution),
    ]
    summary(discriminator, input_size=discriminator_input_size)