Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,449 Bytes
a7dedf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import List, Optional
from .csrnet import _csrnet, _csrnet_bn
from ..utils import _init_weights
EPS: float = 1e-6  # Added to the sum of scale weights in ContextualModule.forward so the weighted average never divides by zero.
class ContextualModule(nn.Module):
    """Fuse multi-scale pooled context into a feature map with learned per-pixel weights.

    For each size in ``scales`` the input is adaptively average-pooled to
    ``size x size``, projected with a bias-free 1x1 convolution and bilinearly
    upsampled back to the input's spatial size. Each upsampled map gets a weight
    map ``sigmoid(weight_net(feature - scale_feature))``. The maps are combined
    as a weighted average (normalised by the summed weights). The result is
    concatenated with the input along dim 1 and projected to ``out_channels``
    by a 1x1 bottleneck followed by ReLU.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the bottleneck.
        scales: Output sizes of the adaptive-pooling branches. Defaults to
            ``[1, 2, 3, 6]``. A private copy is stored on ``self.scales``.

    Raises:
        ValueError: If ``scales`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 512,
        scales: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        # Copy so this module never aliases a caller's list or a shared default object.
        scales = [1, 2, 3, 6] if scales is None else list(scales)
        if not scales:
            # With no branches, forward() would end up concatenating a plain number with a tensor.
            raise ValueError("scales must contain at least one pooling size.")
        self.scales = scales
        self.multiscale_modules = nn.ModuleList([self.__make_scale__(in_channels, size) for size in scales])
        # The fused context (in_channels) is concatenated with the input (in_channels).
        self.bottleneck = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.weight_net = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.apply(_init_weights)

    def __make_weight__(self, feature: Tensor, scale_feature: Tensor) -> Tensor:
        """Return sigmoid weights for ``scale_feature`` computed from its difference to ``feature``."""
        return torch.sigmoid(self.weight_net(feature - scale_feature))

    def __make_scale__(self, channels: int, size: int) -> nn.Module:
        """Build one context branch: adaptive average pool to ``size x size``, then a bias-free 1x1 conv."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(size, size)),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
        )

    def forward(self, feature: Tensor) -> Tensor:
        """Return ``relu(bottleneck(cat([weighted_context, feature], dim=1)))``.

        ``feature`` must be 4-D with ``in_channels`` at dim 1 (required by the
        1x1 convolutions and bilinear interpolation).
        """
        h, w = feature.shape[-2:]
        scale_feats = [
            F.interpolate(branch(feature), size=(h, w), mode="bilinear", align_corners=False)
            for branch in self.multiscale_modules
        ]
        weights = [self.__make_weight__(feature, scale_feat) for scale_feat in scale_feats]
        # Weighted average over scales; EPS keeps the denominator away from zero.
        fused = sum(scale_feat * weight for scale_feat, weight in zip(scale_feats, weights)) / (sum(weights) + EPS)
        overall_features = torch.cat([fused, feature], dim=1)
        return self.relu(self.bottleneck(overall_features))
class CANNet(nn.Module):
    """CSRNet backbone whose refiner is followed by a ContextualModule.

    The encoder and decoder are taken directly from the CSRNet built by
    ``_csrnet`` / ``_csrnet_bn``. The refiner is that CSRNet's refiner
    followed by a ContextualModule producing 512 channels.

    Args:
        model_name: ``"csrnet"`` or ``"csrnet_bn"``; selects the backbone builder.
        block_size: ``None`` or one of 8, 16, 32; forwarded to the backbone builder.
        norm: Forwarded to the backbone builder.
        act: Forwarded to the backbone builder.
        scales: Pooling sizes for the ContextualModule. Defaults to ``[1, 2, 3, 6]``.
            A private list copy is stored on ``self.scales``.

    Raises:
        AssertionError: If any argument fails validation.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
        scales: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        if scales is None:
            scales = [1, 2, 3, 6]
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so callers still get AssertionError.
        assert model_name in ["csrnet", "csrnet_bn"], f"Model name should be one of ['csrnet', 'csrnet_bn'], but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
        assert isinstance(scales, (tuple, list)), f"scales should be a list or tuple, got {type(scales)}."
        assert len(scales) > 0, f"Expected at least one size, got {len(scales)}."
        assert all(isinstance(size, int) for size in scales), f"Expected all size to be int, got {scales}."
        # Private copy: the model must not alias the caller's list or a shared default.
        scales = list(scales)
        self.model_name = model_name
        self.scales = scales
        if model_name == "csrnet":
            csrnet = _csrnet(block_size=block_size, norm=norm, act=act)
        else:
            csrnet = _csrnet_bn(block_size=block_size, norm=norm, act=act)
        self.block_size = csrnet.block_size
        self.encoder = csrnet.encoder
        self.encoder_channels = csrnet.encoder_channels
        self.encoder_reduction = csrnet.encoder_reduction  # feature map size compared to input size
        refiner_channels = 512  # output channels of the appended ContextualModule
        # NOTE(review): reads `refine_channels` while sibling attributes use the `refiner_` prefix — confirm against csrnet.
        self.refiner = nn.Sequential(
            csrnet.refiner,
            ContextualModule(csrnet.refine_channels, refiner_channels, scales),
        )
        # NOTE(review): assumes csrnet.decoder consumes `refiner_channels` (512) input channels — confirm.
        self.refiner_channels = refiner_channels
        self.refiner_reduction = csrnet.refiner_reduction  # feature map size compared to input size
        self.decoder = csrnet.decoder
        self.decoder_channels = csrnet.decoder_channels
        self.decoder_reduction = csrnet.decoder_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the CSRNet encoder on ``x``."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Run the CSRNet refiner followed by the ContextualModule."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Run the CSRNet decoder on refined features."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Apply encode, refine and decode in sequence."""
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x
def _cannet(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: Optional[List[int]] = None) -> CANNet:
    """Build a CANNet on the plain ``csrnet`` backbone.

    ``scales`` defaults to ``[1, 2, 3, 6]``. A fresh list is created per call
    so models built with the default never share one list object.
    """
    if scales is None:
        scales = [1, 2, 3, 6]
    return CANNet("csrnet", block_size=block_size, norm=norm, act=act, scales=scales)
def _cannet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none", scales: Optional[List[int]] = None) -> CANNet:
    """Build a CANNet on the ``csrnet_bn`` backbone.

    ``scales`` defaults to ``[1, 2, 3, 6]``. A fresh list is created per call
    so models built with the default never share one list object.
    """
    if scales is None:
        scales = [1, 2, 3, 6]
    return CANNet("csrnet_bn", block_size=block_size, norm=norm, act=act, scales=scales)
|