"""Upsampling blocks: bilinear interpolation followed by convolutional refinement."""

from functools import partial
from typing import Type, Union

from torch import nn, Tensor
from torch.nn import functional as F

from .refine import ConvRefine, LightConvRefine, LighterConvRefine
from .utils import _init_weights

class ConvUpsample(nn.Module):
    """Bilinear upsampling by ``scale_factor``, followed by a ConvRefine block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[Type[nn.BatchNorm2d], Type[nn.GroupNorm], None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        # With scale_factor == 1 there is nothing to upsample, so fall back to a
        # no-op; otherwise bind the interpolation settings once via partial.
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()
        self.refine = ConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
            groups=groups,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x
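
# A minimal usage sketch, not part of the module, assuming ConvRefine preserves
# spatial size and maps in_channels -> out_channels: a 2x ConvUpsample turns a
# (1, 64, 32, 32) feature map into (1, 32, 64, 64):
#
#   import torch
#   up = ConvUpsample(in_channels=64, out_channels=32, scale_factor=2)
#   assert up(torch.randn(1, 64, 32, 32)).shape == (1, 32, 64, 64)
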
class LightConvUpsample(nn.Module):
    """Like ConvUpsample, but refines with the lighter LightConvRefine block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[Type[nn.BatchNorm2d], Type[nn.GroupNorm], None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        # Same upsampling path as ConvUpsample: identity when scale_factor == 1.
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()
        self.refine = LightConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x
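
# Usage sketch (hypothetical shapes, assuming LightConvRefine has the same
# in/out contract as ConvRefine): with scale_factor=1 the upsample path is
# nn.Identity(), so only the refinement convolution runs:
#
#   import torch
#   up = LightConvUpsample(in_channels=64, out_channels=64, scale_factor=1)
#   assert up(torch.randn(1, 64, 32, 32)).shape == (1, 64, 32, 32)
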
class LighterConvUpsample(nn.Module):
    """Like ConvUpsample, but refines with the lightest LighterConvRefine block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[Type[nn.BatchNorm2d], Type[nn.GroupNorm], None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor
        # Same upsampling path as ConvUpsample: identity when scale_factor == 1.
        self.upsample = partial(
            F.interpolate,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
            antialias=False,
        ) if scale_factor > 1 else nn.Identity()
        self.refine = LighterConvRefine(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample(x)
        x = self.refine(x)
        return x
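
# Usage sketch, assuming LighterConvRefine follows the same interface: larger
# factors are handled in a single bilinear step, e.g. a 4x upsample below.
# Stacking two 2x blocks is a plausible alternative when more gradual
# refinement is wanted, at the cost of an extra conv:
#
#   import torch
#   up = LighterConvUpsample(in_channels=32, out_channels=16, scale_factor=4)
#   assert up(torch.randn(1, 32, 16, 16)).shape == (1, 16, 64, 64)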