# Spaces: Running on Zero
from functools import partial
from typing import Callable, Optional

from torch import nn

from .utils import _init_weights, interpolate_pos_embed
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3, Conv2dLayerNorm
from .refine import ConvRefine, LightConvRefine, LighterConvRefine
from .downsample import ConvDownsample, LightConvDownsample, LighterConvDownsample
from .upsample import ConvUpsample, LightConvUpsample, LighterConvUpsample
from .multi_scale import MultiScale
from .blocks import ConvAdapter, ViTAdapter
def _get_norm_layer(model: nn.Module) -> Optional[Callable[..., nn.Module]]:
    """Infer a normalization-layer factory from the norms already used in ``model``.

    Walks ``model.modules()`` (which yields ``model`` itself first, then its
    submodules) and decides from the first normalization module it finds.

    Args:
        model: The network to inspect.

    Returns:
        A callable that builds a normalization layer -- a class or factory,
        not a layer instance:

        * ``nn.BatchNorm2d`` if the first norm found is an ``nn.BatchNorm2d``;
        * ``partial(nn.GroupNorm, num_groups=...)`` reusing that module's
          group count if it is an ``nn.GroupNorm``;
        * ``Conv2dLayerNorm`` if it is an ``nn.LayerNorm`` or a
          ``Conv2dLayerNorm``;
        * ``None`` if ``model`` contains none of these.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            return nn.BatchNorm2d
        if isinstance(module, nn.GroupNorm):
            # Only the group count is bound; the caller still supplies the
            # remaining GroupNorm arguments (e.g. num_channels).
            return partial(nn.GroupNorm, num_groups=module.num_groups)
        if isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
            # A plain nn.LayerNorm is mapped to Conv2dLayerNorm -- presumably a
            # LayerNorm variant usable on conv feature maps; confirm in .blocks.
            return Conv2dLayerNorm
    return None
def _get_activation(model: nn.Module) -> nn.Module:
    """Pick an activation module to pair with the normalization used in ``model``.

    Walks ``model.modules()`` (``model`` itself first, then its submodules)
    and decides from the first normalization module it finds.

    Args:
        model: The network to inspect.

    Returns:
        A freshly constructed activation module (never ``None``):

        * ``nn.ReLU(inplace=True)`` if the first norm found is an
          ``nn.BatchNorm2d`` or ``nn.GroupNorm``;
        * ``nn.GELU()`` if it is an ``nn.LayerNorm`` or ``Conv2dLayerNorm``,
          or if ``model`` contains none of these norms.
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            return nn.ReLU(inplace=True)
        # A LayerNorm-style module must stop the search here; otherwise a
        # BatchNorm/GroupNorm appearing later in the model would select ReLU.
        if isinstance(module, (nn.LayerNorm, Conv2dLayerNorm)):
            return nn.GELU()
    # No recognised normalization layer: fall back to GELU.
    return nn.GELU()
# Public API of this package. The underscore-prefixed helpers are listed
# explicitly so that ``from <package> import *`` re-exports them; every entry
# must name an object actually bound in this module, or the star-import
# raises AttributeError.
__all__ = [
    "_init_weights",
    "_get_norm_layer",
    "_get_activation",
    "conv1x1",
    "conv3x3",
    "Conv2dLayerNorm",
    "interpolate_pos_embed",
    "DepthSeparableConv2d",
    "ConvRefine",
    "LightConvRefine",
    "LighterConvRefine",
    "ConvDownsample",
    "LightConvDownsample",
    "LighterConvDownsample",
    "ConvUpsample",
    "LightConvUpsample",
    "LighterConvUpsample",
    "MultiScale",
    "ConvAdapter",
    "ViTAdapter",
]