Spaces:
Running
on
Zero
Running
on
Zero
from torch import nn, Tensor | |
from torch.hub import load_state_dict_from_url | |
from typing import Optional | |
from .utils import make_vgg_layers, vgg_cfgs, vgg_urls | |
from ..utils import _init_weights, _get_norm_layer, _get_activation | |
from ..utils import ConvDownsample, ConvUpsample | |
# Accepted model names: every supported depth, first the plain variant and
# then its "_bn" counterpart (vgg11, vgg11_bn, vgg13, ...).
vgg_models = [
    f"vgg{depth}{suffix}"
    for depth in (11, 13, 16, 19)
    for suffix in ("", "_bn")
]

# Layer configuration handed to make_vgg_layers for the extra stack that
# VGGEncoderDecoder appends after the backbone; its last entry is used as
# that model's encoder channel count.
decoder_cfg = [512, 256, 128]
class VGGEncoder(nn.Module):
    """VGG backbone followed by a resampling refiner and an identity decoder.

    The backbone is produced by the builder in this module whose name matches
    ``model_name``. The refiner is a ``ConvUpsample`` or ``ConvDownsample``
    chosen by comparing the recorded backbone reduction (16) against
    ``block_size``.

    Args:
        model_name: One of the names listed in ``vgg_models``.
        block_size: ``None`` or one of 8, 16, 32. ``None`` falls back to the
            recorded backbone reduction.
        norm: ``"bn"`` or ``"ln"`` to force a norm layer class; any other
            value defers to ``_get_norm_layer`` on the backbone.
        act: ``"relu"`` or ``"gelu"`` to force an activation; any other value
            defers to ``_get_activation`` on the backbone.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not accepted.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        # Name -> builder lookup; anything not listed falls through to
        # vgg19_bn, matching the final branch of the original selection.
        builders = {
            "vgg11": vgg11,
            "vgg11_bn": vgg11_bn,
            "vgg13": vgg13,
            "vgg13_bn": vgg13_bn,
            "vgg16": vgg16,
            "vgg16_bn": vgg16_bn,
            "vgg19": vgg19,
        }
        self.encoder = builders.get(model_name, vgg19_bn)()
        self.encoder_channels = 512
        self.encoder_reduction = 16

        if block_size is None:
            self.block_size = self.encoder_reduction
        else:
            self.block_size = block_size

        # NOTE(review): nn.LayerNorm is constructed from a normalized shape,
        # not a channel count -- confirm the refiner modules handle "ln".
        explicit_norms = {"bn": nn.BatchNorm2d, "ln": nn.LayerNorm}
        if norm in explicit_norms:
            norm_layer = explicit_norms[norm]
        else:
            norm_layer = _get_norm_layer(self.encoder)

        if act in ("relu", "gelu"):
            activation = nn.ReLU(inplace=True) if act == "relu" else nn.GELU()
        else:
            activation = _get_activation(self.encoder)

        # Arguments common to both refiner flavours.
        refiner_kwargs = dict(
            in_channels=self.encoder_channels,
            out_channels=self.encoder_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        if self.block_size <= self.encoder_reduction:  # block_size 8 or 16
            self.refiner = ConvUpsample(
                scale_factor=self.encoder_reduction // self.block_size,
                **refiner_kwargs,
            )
        else:  # block_size 32
            self.refiner = ConvDownsample(**refiner_kwargs)
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        self.decoder = nn.Identity()
        self.decoder_channels = self.encoder_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Apply the backbone to ``x``."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Apply the refiner to ``x``."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Apply the decoder (an identity module) to ``x``."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``encode``, then ``refine``, then ``decode`` on ``x``."""
        return self.decode(self.refine(self.encode(x)))
class VGGEncoderDecoder(nn.Module):
    """VGG backbone plus an extra conv stack, a refiner and an identity decoder.

    The backbone built for ``model_name`` is chained with a stack created by
    ``make_vgg_layers(decoder_cfg, ...)``; the pair is stored together as
    ``self.encoder``. ``self.decoder`` itself is an ``nn.Identity``. The
    refiner is a ``ConvUpsample`` or ``ConvDownsample`` chosen by comparing
    the recorded reduction (16) against ``block_size``.

    Args:
        model_name: One of the names listed in ``vgg_models``.
        block_size: ``None`` or one of 8, 16, 32. ``None`` falls back to the
            recorded reduction.
        norm: ``"bn"`` or ``"ln"`` to force a norm layer class; any other
            value defers to ``_get_norm_layer`` on the backbone.
        act: ``"relu"`` or ``"gelu"`` to force an activation; any other value
            defers to ``_get_activation`` on the backbone.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not accepted.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        # Name -> builder lookup; anything not listed falls through to
        # vgg19_bn, matching the final branch of the original selection.
        builders = {
            "vgg11": vgg11,
            "vgg11_bn": vgg11_bn,
            "vgg13": vgg13,
            "vgg13_bn": vgg13_bn,
            "vgg16": vgg16,
            "vgg16_bn": vgg16_bn,
            "vgg19": vgg19,
        }
        backbone = builders.get(model_name, vgg19_bn)()
        backbone_channels = 512
        backbone_reduction = 16

        # Extra stack appended after the backbone; batch norm is enabled
        # whenever the model name contains "bn".
        head = make_vgg_layers(
            decoder_cfg,
            in_channels=backbone_channels,
            batch_norm="bn" in model_name,
            dilation=1,
        )
        head.apply(_init_weights)

        # NOTE(review): nn.LayerNorm is constructed from a normalized shape,
        # not a channel count -- confirm the refiner modules handle "ln".
        explicit_norms = {"bn": nn.BatchNorm2d, "ln": nn.LayerNorm}
        if norm in explicit_norms:
            norm_layer = explicit_norms[norm]
        else:
            norm_layer = _get_norm_layer(backbone)

        if act in ("relu", "gelu"):
            activation = nn.ReLU(inplace=True) if act == "relu" else nn.GELU()
        else:
            activation = _get_activation(backbone)

        self.encoder = nn.Sequential(backbone, head)
        self.encoder_channels = decoder_cfg[-1]
        self.encoder_reduction = backbone_reduction

        if block_size is None:
            self.block_size = self.encoder_reduction
        else:
            self.block_size = block_size

        # Arguments common to both refiner flavours.
        refiner_kwargs = dict(
            in_channels=self.encoder_channels,
            out_channels=self.encoder_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        if self.block_size <= self.encoder_reduction:
            self.refiner = ConvUpsample(
                scale_factor=self.encoder_reduction // self.block_size,
                **refiner_kwargs,
            )
        else:
            self.refiner = ConvDownsample(**refiner_kwargs)
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        self.decoder = nn.Identity()
        self.decoder_channels = self.refiner_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Apply the backbone-plus-stack sequence to ``x``."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Apply the refiner to ``x``."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Apply the decoder (an identity module) to ``x``."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``encode``, then ``refine``, then ``decode`` on ``x``."""
        return self.decode(self.refine(self.encode(x)))
class VGG(nn.Module):
    """Container module that holds a prebuilt stack under ``features``.

    Args:
        features: The module applied to the input in ``forward``.
    """

    def __init__(self, features: nn.Module) -> None:
        super().__init__()
        self.features = features

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.features`` applied to ``x``."""
        return self.features(x)
def vgg11() -> VGG:
    """Build a VGG from layer configuration "A" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg11"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["A"])
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg11"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg11_bn() -> VGG:
    """Build a batch-norm VGG from layer configuration "A" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg11_bn"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["A"], batch_norm=True)
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg11_bn"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg13() -> VGG:
    """Build a VGG from layer configuration "B" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg13"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["B"])
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg13"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg13_bn() -> VGG:
    """Build a batch-norm VGG from layer configuration "B" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg13_bn"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["B"], batch_norm=True)
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg13_bn"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg16() -> VGG:
    """Build a VGG from layer configuration "D" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg16"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["D"])
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg16"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg16_bn() -> VGG:
    """Build a batch-norm VGG from layer configuration "D" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg16_bn"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["D"], batch_norm=True)
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg16_bn"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg19() -> VGG:
    """Build a VGG from layer configuration "E" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg19"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["E"])
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg19"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def vgg19_bn() -> VGG:
    """Build a batch-norm VGG from layer configuration "E" and load its checkpoint.

    The state dict is fetched from ``vgg_urls["vgg19_bn"]`` and loaded with
    ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The constructed ``VGG`` with the fetched weights applied.
    """
    features = make_vgg_layers(vgg_cfgs["E"], batch_norm=True)
    net = VGG(features)
    checkpoint = load_state_dict_from_url(vgg_urls["vgg19_bn"])
    net.load_state_dict(state_dict=checkpoint, strict=False)
    return net
def _vgg_encoder(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> VGGEncoder:
    """Construct a ``VGGEncoder``, forwarding every argument unchanged.

    Args:
        model_name: Name of the VGG variant to build.
        block_size: Optional block size passed to ``VGGEncoder``.
        norm: Norm selector passed to ``VGGEncoder``.
        act: Activation selector passed to ``VGGEncoder``.

    Returns:
        The newly constructed ``VGGEncoder``.
    """
    return VGGEncoder(
        model_name=model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )
def _vgg_encoder_decoder(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> VGGEncoderDecoder:
    """Construct a ``VGGEncoderDecoder``, forwarding every argument unchanged.

    Args:
        model_name: Name of the VGG variant to build.
        block_size: Optional block size passed to ``VGGEncoderDecoder``.
        norm: Norm selector passed to ``VGGEncoderDecoder``.
        act: Activation selector passed to ``VGGEncoderDecoder``.

    Returns:
        The newly constructed ``VGGEncoderDecoder``.
    """
    return VGGEncoderDecoder(
        model_name=model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )