# ZIP/models/utils/downsample.py
# Yiming-M · 2025-07-31 18:59 · commit a7dedf9
from torch import nn, Tensor
from typing import Union
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3
from .utils import _init_weights
class ConvDownsample(nn.Module):
    """Residual block that downsamples a feature map by a factor of 2.

    Main branch:
        2x2 average pooling -> norm -> activation
        -> ``conv3x3`` (in -> in) -> norm -> activation
        -> ``conv3x3`` (in -> out) -> norm
    Shortcut branch:
        2x2 average pooling -> ``conv1x1`` (in -> out) -> norm

    The two branches are added and the sum goes through the activation.

    When ``groups > 1`` the grouped ``conv3x3`` layers are each followed by
    an ungrouped ``conv1x1`` layer. The attribute names ``conv1`` and
    ``conv1_1x1`` are retained from an earlier variant that used a strided
    2x2 convolution for the first stage; that stage is now an average
    pooling layer (followed by an identity in the grouped case).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        norm_layer: Factory invoked as ``norm_layer(num_channels)`` to build
            each normalisation layer; ``None`` (or any falsy value) replaces
            every normalisation layer by ``nn.Identity`` and passes
            ``bias=True`` to the convolution helpers.
            NOTE(review): ``nn.GroupNorm`` takes ``num_groups`` as its first
            argument, so it would have to be pre-bound (e.g. with
            ``functools.partial``) before being passed here.
        activation: Module instance used for all three activations; the
            same object is registered as ``act1``, ``act2`` and ``act3``.
        groups: Number of groups for the 3x3 convolutions. Must be a positive
            integer dividing both ``in_channels`` and ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        assert isinstance(groups, int) and groups > 0, f"Number of groups should be an integer greater than 0, but got {groups}."
        assert in_channels % groups == 0, f"Number of input channels {in_channels} should be divisible by number of groups {groups}."
        assert out_channels % groups == 0, f"Number of output channels {out_channels} should be divisible by number of groups {groups}."

        # A bias is only requested when no normalisation layer follows.
        use_bias = not norm_layer

        def make_norm(num_channels: int) -> nn.Module:
            # Fall back to a no-op when normalisation is disabled.
            if norm_layer:
                return norm_layer(num_channels)
            return nn.Identity()

        self.grouped_conv = groups > 1

        # Stage 1: spatial downsampling by 2 (parameter-free).
        self.conv1 = nn.AvgPool2d(kernel_size=2, stride=2)
        if self.grouped_conv:
            self.conv1_1x1 = nn.Identity()
        self.norm1 = make_norm(in_channels)
        self.act1 = activation

        # Stage 2: 3x3 refinement that keeps the channel count.
        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=groups,
            bias=use_bias,
        )
        if self.grouped_conv:
            self.conv2_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=use_bias)
        self.norm2 = make_norm(in_channels)
        self.act2 = activation

        # Stage 3: 3x3 refinement that maps to the output channel count.
        self.conv3 = conv3x3(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            groups=groups,
            bias=use_bias,
        )
        if self.grouped_conv:
            self.conv3_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=use_bias)
        self.norm3 = make_norm(out_channels)
        self.act3 = activation

        # Shortcut: pooled so that its spatial size matches the main branch,
        # then projected to the output channel count.
        shortcut_layers = [
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=use_bias),
            make_norm(out_channels),
        ]
        self.downsample = nn.Sequential(*shortcut_layers)

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Run both branches on ``x`` and return the activated sum."""
        # Stage 1: downsample.
        main = self.conv1(x)
        if self.grouped_conv:
            main = self.conv1_1x1(main)
        main = self.act1(self.norm1(main))

        # Stage 2.
        main = self.conv2(main)
        if self.grouped_conv:
            main = self.conv2_1x1(main)
        main = self.act2(self.norm2(main))

        # Stage 3 (no activation before the residual addition).
        main = self.conv3(main)
        if self.grouped_conv:
            main = self.conv3_1x1(main)
        main = self.norm3(main)

        # Residual addition with the projected shortcut, then activation.
        main += self.downsample(x)
        return self.act3(main)
class LightConvDownsample(nn.Module):
    """Residual downsampling block built from ``DepthSeparableConv2d`` layers.

    Main branch:
        ``DepthSeparableConv2d`` (kernel 2, stride 2, in -> in) -> norm -> act
        -> ``DepthSeparableConv2d`` (kernel 3, in -> out) -> norm -> act
        -> ``DepthSeparableConv2d`` (kernel 3, out -> out) -> norm
    Shortcut branch:
        2x2 average pooling -> ``conv1x1`` (in -> out) -> norm

    The two branches are added and the sum goes through the activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        norm_layer: Factory invoked as ``norm_layer(num_channels)`` to build
            each normalisation layer; ``None`` (or any falsy value) replaces
            every normalisation layer by ``nn.Identity`` and passes
            ``bias=True`` to the convolution layers.
            NOTE(review): ``nn.GroupNorm`` takes ``num_groups`` as its first
            argument, so it would have to be pre-bound (e.g. with
            ``functools.partial``) before being passed here.
        activation: Module instance used for all three activations; the
            same object is registered as ``act1``, ``act2`` and ``act3``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # A bias is only requested when no normalisation layer follows.
        use_bias = not norm_layer

        # Stage 1: stride-2 convolution performs the downsampling.
        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=use_bias,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        # Stage 2: first refinement, maps to the output channel count.
        self.conv2 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
        )
        self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        # Stage 3: second refinement at the output channel count.
        self.conv3 = DepthSeparableConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        # Shortcut: pooled so that its spatial size matches the main branch,
        # then projected to the output channel count.
        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=use_bias),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Run both branches on ``x`` and return the activated sum."""
        identity = x

        # Downsample.
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        # Refine 1.
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        # Refine 2 (no activation before the residual addition).
        out = self.conv3(out)
        out = self.norm3(out)

        # Shortcut.
        out += self.downsample(identity)
        out = self.act3(out)
        # Return the block's result; the previous code returned the
        # untouched input ``x`` and discarded everything computed above.
        return out
class LighterConvDownsample(nn.Module):
    """Residual downsampling block using a depthwise + pointwise refinement.

    Main branch:
        ``DepthSeparableConv2d`` (kernel 2, stride 2, in -> in) -> norm -> act
        -> ``conv3x3`` with ``groups=in_channels`` (depthwise) -> norm -> act
        -> ``conv1x1`` (pointwise, in -> out) -> norm
    Shortcut branch:
        2x2 average pooling -> ``conv1x1`` (in -> out) -> norm

    The two branches are added and the sum goes through the activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        norm_layer: Factory invoked as ``norm_layer(num_channels)`` to build
            each normalisation layer; ``None`` (or any falsy value) replaces
            every normalisation layer by ``nn.Identity`` and passes
            ``bias=True`` to the convolution layers.
            NOTE(review): ``nn.GroupNorm`` takes ``num_groups`` as its first
            argument, so it would have to be pre-bound (e.g. with
            ``functools.partial``) before being passed here.
        activation: Module instance used for all three activations; the
            same object is registered as ``act1``, ``act2`` and ``act3``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # A bias is only requested when no normalisation layer follows.
        use_bias = not norm_layer

        # Stage 1: stride-2 convolution performs the downsampling.
        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=use_bias,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        # Stage 2: depthwise 3x3 refinement (one group per input channel).
        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=in_channels,
            bias=use_bias,
        )
        self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        # Stage 3: pointwise projection to the output channel count.
        self.conv3 = conv1x1(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            bias=use_bias,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        # Shortcut: pooled so that its spatial size matches the main branch,
        # then projected to the output channel count.
        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=use_bias),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        # Initialise weights the same way ConvDownsample and
        # LightConvDownsample do; this call was missing here, leaving this
        # variant on different initial weights from its siblings.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Run both branches on ``x`` and return the activated sum."""
        identity = x

        # Downsample.
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        # Refine, depthwise conv.
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        # Refine, pointwise conv (no activation before the residual addition).
        out = self.conv3(out)
        out = self.norm3(out)

        # Shortcut.
        out += self.downsample(identity)
        out = self.act3(out)
        return out