# ZIP/models/utils/refine.py
# Repository page metadata: Yiming-M, 2025-07-31 18:59, 🐣, commit a7dedf9
from torch import nn, Tensor
from typing import Union
from .utils import _init_weights
from .blocks import BasicBlock, LightBasicBlock, conv1x1, conv3x3
class ConvRefine(nn.Module):
    """Refinement module that wraps a single ``BasicBlock``.

    All constructor arguments are forwarded unchanged to ``BasicBlock``;
    afterwards ``_init_weights`` is applied to this module and every
    submodule via ``nn.Module.apply``.

    Args:
        in_channels: Forwarded to ``BasicBlock`` as ``in_channels``.
        out_channels: Forwarded to ``BasicBlock`` as ``out_channels``.
        norm_layer: Forwarded to ``BasicBlock`` as ``norm_layer``.
            Defaults to ``nn.BatchNorm2d``.
        activation: Forwarded to ``BasicBlock`` as ``activation``. The
            default ``nn.ReLU(inplace=True)`` is created once, when the
            class is defined, so every instance relying on the default
            receives the same module object.
        groups: Forwarded to ``BasicBlock`` as ``groups``.

    NOTE(review): how ``BasicBlock`` interprets these arguments is defined
    in ``.blocks`` and is not visible here -- confirm there.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()

        # Collect the block configuration first, then build the block.
        block_kwargs = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm_layer": norm_layer,
            "activation": activation,
            "groups": groups,
        }
        self.refine = BasicBlock(**block_kwargs)

        # Run the project's weight initialiser over every submodule.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the wrapped block for input ``x``."""
        refined = self.refine(x)
        return refined
class LightConvRefine(nn.Module):
    """Refinement module that wraps a single ``LightBasicBlock``.

    All constructor arguments are forwarded unchanged to
    ``LightBasicBlock``; afterwards ``_init_weights`` is applied to this
    module and every submodule via ``nn.Module.apply``.

    Args:
        in_channels: Forwarded to ``LightBasicBlock`` as ``in_channels``.
        out_channels: Forwarded to ``LightBasicBlock`` as ``out_channels``.
        norm_layer: Forwarded to ``LightBasicBlock`` as ``norm_layer``.
            Defaults to ``nn.BatchNorm2d``.
        activation: Forwarded to ``LightBasicBlock`` as ``activation``.
            The default ``nn.ReLU(inplace=True)`` is created once, when
            the class is defined, so every instance relying on the default
            receives the same module object.

    NOTE(review): how ``LightBasicBlock`` interprets these arguments is
    defined in ``.blocks`` and is not visible here -- confirm there.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()

        # Collect the block configuration first, then build the block.
        block_kwargs = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm_layer": norm_layer,
            "activation": activation,
        }
        self.refine = LightBasicBlock(**block_kwargs)

        # Run the project's weight initialiser over every submodule.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the wrapped block for input ``x``."""
        refined = self.refine(x)
        return refined
class LighterConvRefine(nn.Module):
    """Residual refinement block built from a depthwise separable convolution.

    Main path: ``conv3x3`` called with ``groups=in_channels`` -> norm ->
    activation -> ``conv1x1`` -> norm. The shortcut path is then added to
    the main path in place and the activation is applied once more.

    The shortcut is ``nn.Identity`` when ``in_channels == out_channels``;
    otherwise it is a ``conv1x1`` followed by a norm layer (or
    ``nn.Identity`` when ``norm_layer`` is falsy).

    Args:
        in_channels: Channel count passed to both convolutions as
            ``in_channels`` (and as ``out_channels``/``groups`` of the
            first one).
        out_channels: Channel count produced by the ``conv1x1`` layers.
        norm_layer: Called as ``norm_layer(num_channels)`` to create each
            normalisation layer. When falsy (e.g. ``None``),
            ``nn.Identity`` is used instead and the convolutions are
            created with ``bias=True``.
        activation: Module used for both ``act1`` and ``act2`` (the same
            object is registered under both names). The default
            ``nn.ReLU(inplace=True)`` is created once, when the class is
            defined, and is therefore shared by every instance that relies
            on the default.

    NOTE(review): the ``norm_layer`` annotation names norm classes, but
    the value is used as a one-argument factory. ``nn.GroupNorm`` itself
    requires ``(num_groups, num_channels)``, so it presumably has to be
    wrapped (e.g. with ``functools.partial``) by the caller -- confirm.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()

        # Convolutions only carry their own bias when no norm layer follows.
        use_bias = not norm_layer

        def make_norm(num_channels: int) -> nn.Module:
            """Build a norm layer for ``num_channels`` or a no-op stand-in."""
            return norm_layer(num_channels) if norm_layer else nn.Identity()

        # Depthwise stage: conv3x3 with one group per input channel.
        self.conv1 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=in_channels,
            bias=use_bias,
        )
        self.norm1 = make_norm(in_channels)
        self.act1 = activation

        # Pointwise stage: conv1x1 maps in_channels to out_channels.
        self.conv2 = conv1x1(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            bias=use_bias,
        )
        self.norm2 = make_norm(out_channels)
        self.act2 = activation

        # Shortcut path: only project when the channel counts differ.
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride=1, bias=use_bias),
                make_norm(out_channels),
            )

        # Run the project's weight initialiser over every submodule.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the main path, add the shortcut in place, then activate."""
        features = self.act1(self.norm1(self.conv1(x)))
        features = self.norm2(self.conv2(features))
        # In-place accumulation of the shortcut branch onto the main path.
        features += self.downsample(x)
        return self.act2(features)