ZIP / models /ebc /vgg.py
Yiming-M's picture
2025-07-31 22:05 🚀
a5dc50a
raw
history blame
8.79 kB
from torch import nn, Tensor
from torch.hub import load_state_dict_from_url
from typing import Optional
from .utils import make_vgg_layers, vgg_cfgs, vgg_urls
from ..utils import _init_weights, _get_norm_layer, _get_activation
from ..utils import ConvDownsample, ConvUpsample
# Backbone names accepted by the encoder classes below: every VGG depth,
# first without and then with the "_bn" suffix.
vgg_models = [
    f"vgg{depth}{suffix}"
    for depth in (11, 13, 16, 19)
    for suffix in ("", "_bn")
]

# Layer configuration handed to make_vgg_layers for the extra stage that
# VGGEncoderDecoder stacks on top of the backbone; its last entry is used
# as that model's encoder channel count.
decoder_cfg = [512, 256, 128]
class VGGEncoder(nn.Module):
    """VGG backbone followed by a refiner that sets the output reduction.

    The forward pass is ``encoder -> refiner -> decoder``; the decoder is an
    ``nn.Identity``.  The encoder is recorded as producing 512 channels at a
    reduction of 16.  The refiner is a ``ConvUpsample`` when the requested
    block size does not exceed that reduction (8 or 16) and a
    ``ConvDownsample`` otherwise (32).

    Args:
        model_name: Backbone name; must be listed in ``vgg_models``.
        block_size: Requested output reduction (8, 16 or 32).  ``None`` means
            "use the encoder reduction".
        norm: ``"bn"`` selects ``nn.BatchNorm2d`` and ``"ln"`` selects
            ``nn.LayerNorm``; any other value defers to
            ``_get_norm_layer(encoder)``.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)`` and ``"gelu"``
            selects ``nn.GELU()``; any other value defers to
            ``_get_activation(encoder)``.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not supported.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, (
            f"Model name should be one of {vgg_models}, but got {model_name}."
        )
        assert block_size is None or block_size in (8, 16, 32), (
            f"Block size should be one of [8, 16, 32], but got {block_size}."
        )
        self.model_name = model_name

        # Name -> factory lookup.  Anything not in the table falls back to
        # vgg19_bn, mirroring a trailing ``else`` branch.
        backbone_factories = {
            "vgg11": vgg11,
            "vgg11_bn": vgg11_bn,
            "vgg13": vgg13,
            "vgg13_bn": vgg13_bn,
            "vgg16": vgg16,
            "vgg16_bn": vgg16_bn,
            "vgg19": vgg19,
            "vgg19_bn": vgg19_bn,
        }
        self.encoder = backbone_factories.get(model_name, vgg19_bn)()
        self.encoder_channels = 512
        self.encoder_reduction = 16

        if block_size is None:
            self.block_size = self.encoder_reduction
        else:
            self.block_size = block_size

        # Explicit norm choice first; otherwise infer it from the backbone.
        # NOTE(review): nn.LayerNorm normalises over trailing dimensions --
        # confirm ConvUpsample/ConvDownsample apply it correctly to feature maps.
        norm_layer = {"bn": nn.BatchNorm2d, "ln": nn.LayerNorm}.get(norm)
        if norm_layer is None:
            norm_layer = _get_norm_layer(self.encoder)

        # Same scheme for the activation; the module is only instantiated
        # for the branch that is actually selected.
        activation_factories = {
            "relu": lambda: nn.ReLU(inplace=True),
            "gelu": nn.GELU,
        }
        make_activation = activation_factories.get(act)
        if make_activation is not None:
            activation = make_activation()
        else:
            activation = _get_activation(self.encoder)

        # Arguments common to both refiner variants.
        refiner_kwargs = dict(
            in_channels=self.encoder_channels,
            out_channels=self.encoder_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        if self.block_size <= self.encoder_reduction:  # 8, 16
            self.refiner = ConvUpsample(
                scale_factor=self.encoder_reduction // self.block_size,
                **refiner_kwargs,
            )
        else:  # 32
            self.refiner = ConvDownsample(**refiner_kwargs)
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        # No decoding stage: the decoder simply forwards the refiner output.
        self.decoder = nn.Identity()
        self.decoder_channels = self.encoder_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the VGG backbone on ``x``."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Run the refiner on encoder features."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Run the (identity) decoder on refined features."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Apply encode, refine and decode in sequence."""
        return self.decode(self.refine(self.encode(x)))
class VGGEncoderDecoder(nn.Module):
    """VGG backbone plus an extra ``decoder_cfg`` stage, then a refiner.

    The backbone and a stack built by ``make_vgg_layers(decoder_cfg, ...)``
    are wrapped together in ``nn.Sequential`` and exposed as ``self.encoder``;
    ``self.decoder`` itself is an ``nn.Identity``.  The combined encoder is
    recorded as producing ``decoder_cfg[-1]`` channels at a reduction of 16.
    The refiner is a ``ConvUpsample`` when the requested block size does not
    exceed that reduction and a ``ConvDownsample`` otherwise.

    Args:
        model_name: Backbone name; must be listed in ``vgg_models``.
        block_size: Requested output reduction (8, 16 or 32).  ``None`` means
            "use the encoder reduction".
        norm: ``"bn"`` selects ``nn.BatchNorm2d`` and ``"ln"`` selects
            ``nn.LayerNorm``; any other value defers to
            ``_get_norm_layer(backbone)``.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)`` and ``"gelu"``
            selects ``nn.GELU()``; any other value defers to
            ``_get_activation(backbone)``.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not supported.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, (
            f"Model name should be one of {vgg_models}, but got {model_name}."
        )
        assert block_size is None or block_size in (8, 16, 32), (
            f"Block size should be one of [8, 16, 32], but got {block_size}."
        )
        self.model_name = model_name

        # Name -> factory lookup.  Anything not in the table falls back to
        # vgg19_bn, mirroring a trailing ``else`` branch.
        backbone_factories = {
            "vgg11": vgg11,
            "vgg11_bn": vgg11_bn,
            "vgg13": vgg13,
            "vgg13_bn": vgg13_bn,
            "vgg16": vgg16,
            "vgg16_bn": vgg16_bn,
            "vgg19": vgg19,
            "vgg19_bn": vgg19_bn,
        }
        backbone = backbone_factories.get(model_name, vgg19_bn)()
        backbone_channels = 512
        backbone_reduction = 16

        # Extra stage on top of the backbone; it uses batch norm exactly when
        # the backbone name contains "bn", and is initialised with
        # _init_weights.
        extra_stage = make_vgg_layers(
            decoder_cfg,
            in_channels=backbone_channels,
            batch_norm="bn" in model_name,
            dilation=1,
        )
        extra_stage.apply(_init_weights)

        # Explicit norm choice first; otherwise infer it from the backbone.
        # NOTE(review): nn.LayerNorm normalises over trailing dimensions --
        # confirm ConvUpsample/ConvDownsample apply it correctly to feature maps.
        norm_layer = {"bn": nn.BatchNorm2d, "ln": nn.LayerNorm}.get(norm)
        if norm_layer is None:
            norm_layer = _get_norm_layer(backbone)

        # Same scheme for the activation; the module is only instantiated
        # for the branch that is actually selected.
        activation_factories = {
            "relu": lambda: nn.ReLU(inplace=True),
            "gelu": nn.GELU,
        }
        make_activation = activation_factories.get(act)
        if make_activation is not None:
            activation = make_activation()
        else:
            activation = _get_activation(backbone)

        # Backbone and extra stage together form the "encoder" of this model.
        self.encoder = nn.Sequential(backbone, extra_stage)
        self.encoder_channels = decoder_cfg[-1]
        self.encoder_reduction = backbone_reduction

        if block_size is None:
            self.block_size = self.encoder_reduction
        else:
            self.block_size = block_size

        # Arguments common to both refiner variants.
        refiner_kwargs = dict(
            in_channels=self.encoder_channels,
            out_channels=self.encoder_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        if self.block_size <= self.encoder_reduction:
            self.refiner = ConvUpsample(
                scale_factor=self.encoder_reduction // self.block_size,
                **refiner_kwargs,
            )
        else:
            self.refiner = ConvDownsample(**refiner_kwargs)
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        # No separate decoding stage: forward the refiner output unchanged.
        self.decoder = nn.Identity()
        self.decoder_channels = self.refiner_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the backbone and the extra stage on ``x``."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Run the refiner on encoder features."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Run the (identity) decoder on refined features."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Apply encode, refine and decode in sequence."""
        return self.decode(self.refine(self.encode(x)))
class VGG(nn.Module):
    """Thin wrapper that exposes a feature stack as ``self.features``.

    Args:
        features: Module applied to the input in ``forward``.
    """

    def __init__(self, features: nn.Module) -> None:
        super().__init__()
        self.features = features

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.features(x)``."""
        return self.features(x)
def vgg11() -> VGG:
    """Build a ``VGG`` from configuration ``"A"`` with default layer options."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["A"])
    return VGG(features)
def vgg11_bn() -> VGG:
    """Build a ``VGG`` from configuration ``"A"`` with ``batch_norm=True``."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg11_bn"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["A"], batch_norm=True)
    return VGG(features)
def vgg13() -> VGG:
    """Build a ``VGG`` from configuration ``"B"`` with default layer options."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["B"])
    return VGG(features)
def vgg13_bn() -> VGG:
    """Build a ``VGG`` from configuration ``"B"`` with ``batch_norm=True``."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg13_bn"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["B"], batch_norm=True)
    return VGG(features)
def vgg16() -> VGG:
    """Build a ``VGG`` from configuration ``"D"`` with default layer options."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["D"])
    return VGG(features)
def vgg16_bn() -> VGG:
    """Build a ``VGG`` from configuration ``"D"`` with ``batch_norm=True``."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg16_bn"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["D"], batch_norm=True)
    return VGG(features)
def vgg19() -> VGG:
    """Build a ``VGG`` from configuration ``"E"`` with default layer options."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["E"])
    return VGG(features)
def vgg19_bn() -> VGG:
    """Build a ``VGG`` from configuration ``"E"`` with ``batch_norm=True``."""
    # Pretrained-weight loading is disabled; the call is kept for reference:
    # model.load_state_dict(state_dict=load_state_dict_from_url(vgg_urls["vgg19_bn"]), strict=False)
    features = make_vgg_layers(vgg_cfgs["E"], batch_norm=True)
    return VGG(features)
def _vgg_encoder(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> VGGEncoder:
    """Construct a ``VGGEncoder``, forwarding every argument unchanged."""
    return VGGEncoder(
        model_name=model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )
def _vgg_encoder_decoder(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> VGGEncoderDecoder:
    """Construct a ``VGGEncoderDecoder``, forwarding every argument unchanged."""
    return VGGEncoderDecoder(
        model_name=model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )