# models/utils/__init__.py
# Author: Yiming-M — last updated 2025-07-31 (commit a7dedf9)
from functools import partial
from typing import Callable, Optional

from torch import nn

from .utils import _init_weights, interpolate_pos_embed
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3, Conv2dLayerNorm
from .blocks import ConvAdapter, ViTAdapter
from .refine import ConvRefine, LightConvRefine, LighterConvRefine
from .downsample import ConvDownsample, LightConvDownsample, LighterConvDownsample
from .upsample import ConvUpsample, LightConvUpsample, LighterConvUpsample
from .multi_scale import MultiScale
def _get_norm_layer(model: nn.Module) -> Optional[nn.Module]:
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
return nn.BatchNorm2d
elif isinstance(module, nn.GroupNorm):
num_groups = module.num_groups
return partial(nn.GroupNorm, num_groups=num_groups)
elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
return Conv2dLayerNorm
return None
def _get_activation(model: nn.Module) -> Optional[nn.Module]:
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
return nn.ReLU(inplace=True)
elif isinstance(module, nn.GroupNorm):
return nn.ReLU(inplace=True)
elif isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
return nn.GELU()
return nn.GELU()
__all__ = [
"_init_weights", "_check_norm_layer", "_check_activation",
"conv1x1",
"conv3x3",
"Conv2dLayerNorm",
"interpolate_pos_embed",
"DepthSeparableConv2d",
"ConvRefine",
"LightConvRefine",
"LighterConvRefine",
"ConvDownsample",
"LightConvDownsample",
"LighterConvDownsample",
"ConvUpsample",
"LightConvUpsample",
"LighterConvUpsample",
"MultiScale",
"ConvAdapter", "ViTAdapter",
]