from typing import Callable, Optional

from torch import nn, Tensor

from .blocks import DepthSeparableConv2d, conv1x1, conv3x3
from .utils import _init_weights


class ConvDownsample(nn.Module):
    """Residual block that halves spatial resolution: average pooling followed by two
    3x3 convolutions on the main path, and average pooling plus a 1x1 convolution on
    the shortcut. With groups > 1, each grouped 3x3 convolution is followed by a 1x1
    convolution to mix channels across groups."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()

        assert isinstance(groups, int) and groups > 0, (
            f"Number of groups should be an integer greater than 0, but got {groups}."
        )
        assert in_channels % groups == 0, (
            f"Number of input channels {in_channels} should be divisible by number of groups {groups}."
        )
        assert out_channels % groups == 0, (
            f"Number of output channels {out_channels} should be divisible by number of groups {groups}."
        )

        self.grouped_conv = groups > 1

        # conv1 is used for downsampling
        # self.conv1 = nn.Conv2d(
        #     in_channels=in_channels,
        #     out_channels=in_channels,
        #     kernel_size=2,
        #     stride=2,
        #     padding=0,
        #     bias=not norm_layer,
        #     groups=groups,
        # )
        # if self.grouped_conv:
        #     self.conv1_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
        self.conv1 = nn.AvgPool2d(kernel_size=2, stride=2)  # downsample by 2
        if self.grouped_conv:
            self.conv1_1x1 = nn.Identity()
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=groups,
            bias=not norm_layer,
        )
        if self.grouped_conv:
            self.conv2_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=not norm_layer)
        self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        self.conv3 = conv3x3(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            groups=groups,
            bias=not norm_layer,
        )
        if self.grouped_conv:
            self.conv3_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=not norm_layer)
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),  # make sure the spatial sizes match
            conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        # downsample
        out = self.conv1(x)
        out = self.conv1_1x1(out) if self.grouped_conv else out
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.conv2_1x1(out) if self.grouped_conv else out
        out = self.norm2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.conv3_1x1(out) if self.grouped_conv else out
        out = self.norm3(out)

        # shortcut
        out += self.downsample(identity)
        out = self.act3(out)

        return out


class LightConvDownsample(nn.Module):
    """Variant of ConvDownsample built from depthwise-separable convolutions: a strided
    2x2 separable convolution for downsampling followed by two 3x3 separable
    convolutions, with the same average-pool + 1x1 convolution shortcut."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()

        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=not norm_layer,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        self.conv2 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not norm_layer,
        )
        self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        self.conv3 = DepthSeparableConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not norm_layer,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),  # make sure the spatial sizes match
            conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        # downsample
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        # refine 1
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        # refine 2
        out = self.conv3(out)
        out = self.norm3(out)

        # shortcut
        out += self.downsample(identity)
        out = self.act3(out)

        return out


class LighterConvDownsample(nn.Module):
    """Lightest variant: a strided 2x2 depthwise-separable convolution for downsampling,
    then a depthwise 3x3 convolution and a pointwise 1x1 convolution, with the same
    average-pool + 1x1 convolution shortcut."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()

        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=not norm_layer,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=in_channels,
            bias=not norm_layer,
        )
        self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        self.conv3 = conv1x1(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            bias=not norm_layer,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),  # make sure the spatial sizes match
            conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        # downsample
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        # refine, depthwise conv
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        # refine, pointwise conv
        out = self.conv3(out)
        out = self.norm3(out)

        # shortcut
        out += self.downsample(identity)
        out = self.act3(out)

        return out
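

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the module's API).
# It assumes this file is imported as part of its package so that the relative
# imports of DepthSeparableConv2d, conv1x1, conv3x3 and _init_weights resolve,
# and that those helpers preserve spatial size at stride 1 (which the residual
# addition in each block already requires). The chosen shapes and group count
# are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 16, 32, 32)  # (batch, channels, height, width)

    blocks = (
        ConvDownsample(in_channels=16, out_channels=32, groups=4),
        LightConvDownsample(in_channels=16, out_channels=32),
        LighterConvDownsample(in_channels=16, out_channels=32),
    )
    for block in blocks:
        y = block(x)
        # Each block halves the spatial resolution and maps 16 -> 32 channels.
        assert y.shape == (2, 32, 16, 16), y.shape
        print(type(block).__name__, tuple(y.shape))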