"""Shared building blocks plus helpers that infer norm/activation choices.

This module re-exports layers from its sibling modules and defines two
helpers that inspect an existing model to pick a matching normalisation
layer factory and activation module.
"""

from torch import nn
from typing import Callable, Optional
from functools import partial

from .utils import _init_weights, interpolate_pos_embed
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3, Conv2dLayerNorm
from .refine import ConvRefine, LightConvRefine, LighterConvRefine
from .downsample import ConvDownsample, LightConvDownsample, LighterConvDownsample
from .upsample import ConvUpsample, LightConvUpsample, LighterConvUpsample
from .multi_scale import MultiScale
from .blocks import ConvAdapter, ViTAdapter


def _get_norm_layer(model: nn.Module) -> Optional[Callable[..., nn.Module]]:
    """Return a norm-layer factory matching the first norm found in *model*.

    The sub-modules of *model* are visited in ``model.modules()`` order and
    the first normalisation layer encountered decides the result:

    * ``nn.BatchNorm2d``  -> the ``nn.BatchNorm2d`` class itself.
    * ``nn.GroupNorm``    -> ``partial(nn.GroupNorm, num_groups=...)`` reusing
      the ``num_groups`` of the layer that was found.
    * ``nn.LayerNorm`` or ``Conv2dLayerNorm`` -> the ``Conv2dLayerNorm`` class.

    Args:
        model: Module whose sub-modules are inspected.

    Returns:
        A callable that builds a norm layer (a class or a ``partial``), or
        ``None`` when *model* contains none of the recognised norm layers.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            return nn.BatchNorm2d
        if isinstance(module, nn.GroupNorm):
            # Carry over the group count so new layers match the existing ones.
            return partial(nn.GroupNorm, num_groups=module.num_groups)
        if isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
            return Conv2dLayerNorm
    return None


def _get_activation(model: nn.Module) -> nn.Module:
    """Return an activation module chosen from the first norm found in *model*.

    The sub-modules of *model* are visited in ``model.modules()`` order and
    the first normalisation layer encountered decides the result:

    * ``nn.BatchNorm2d`` or ``nn.GroupNorm`` -> ``nn.ReLU(inplace=True)``.
    * ``nn.LayerNorm`` or ``Conv2dLayerNorm`` -> ``nn.GELU()``.

    Args:
        model: Module whose sub-modules are inspected.

    Returns:
        A freshly constructed activation module. ``nn.GELU()`` is also the
        fallback when no recognised norm layer is present, so this function
        never returns ``None``.
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            return nn.ReLU(inplace=True)
        if isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
            return nn.GELU()
    return nn.GELU()


# Every entry must name an object that actually exists in this module;
# otherwise ``from <package> import *`` raises AttributeError.
__all__ = [
    "_init_weights",
    "_get_norm_layer",
    "_get_activation",
    "conv1x1",
    "conv3x3",
    "Conv2dLayerNorm",
    "interpolate_pos_embed",
    "DepthSeparableConv2d",
    "ConvRefine",
    "LightConvRefine",
    "LighterConvRefine",
    "ConvDownsample",
    "LightConvDownsample",
    "LighterConvDownsample",
    "ConvUpsample",
    "LightConvUpsample",
    "LighterConvUpsample",
    "MultiScale",
    "ConvAdapter",
    "ViTAdapter",
]