import torch from torch import nn, Tensor from einops import rearrange from typing import Tuple, Union, Dict, Optional, List from functools import partial from .cannet import _cannet, _cannet_bn from .csrnet import _csrnet, _csrnet_bn from .vgg import _vgg_encoder_decoder, _vgg_encoder from .vit import _vit, supported_vit_backbones from .timm_models import _timm_model from .timm_models import regular_models as timm_regular_models, heavy_models as timm_heavy_models, light_models as timm_light_models, lighter_models as timm_lighter_models from .hrnet import _hrnet, available_hrnets from ..utils import conv1x1 regular_models = [ "csrnet", "csrnet_bn", "cannet", "cannet_bn", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn", "vgg11_ae", "vgg11_bn_ae", "vgg13_ae", "vgg13_bn_ae", "vgg16_ae", "vgg16_bn_ae", "vgg19_ae", "vgg19_bn_ae", *timm_regular_models, *available_hrnets, ] heavy_models = timm_heavy_models light_models = timm_light_models lighter_models = timm_lighter_models transformer_models = supported_vit_backbones supported_models = regular_models + heavy_models + light_models + lighter_models + transformer_models class EBC(nn.Module): def __init__( self, model_name: str, block_size: int, bins: List[Tuple[float, float]], bin_centers: List[float], zero_inflated: bool = True, num_vpt: Optional[int] = None, vpt_drop: Optional[float] = None, input_size: Optional[int] = None, norm: str = "none", act: str = "none" ) -> None: super().__init__() assert model_name in supported_models, f"Model name should be one of {supported_models}, but got {model_name}." self.model_name = model_name if input_size is not None: input_size = (input_size, input_size) if isinstance(input_size, int) else input_size assert len(input_size) == 2 and input_size[0] > 0 and input_size[1] > 0, f"Expected input_size to be a tuple of two positive integers, got {input_size}" self.input_size = input_size assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}" assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}" assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}" bins = [(float(b[0]), float(b[1])) for b in bins] assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}" self.block_size = block_size self.bins = bins self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1)) self.zero_inflated = zero_inflated self.num_vpt = num_vpt self.vpt_drop = vpt_drop self.input_size = input_size self.norm = norm self.act = act self._build_backbone() self._build_head() def _build_backbone(self) -> None: model_name = self.model_name if model_name == "csrnet": self.backbone = _csrnet(self.block_size, self.norm, self.act) elif model_name == "csrnet_bn": self.backbone = _csrnet_bn(self.block_size, self.norm, self.act) elif model_name == "cannet": self.backbone = _cannet(self.block_size, self.norm, self.act) elif model_name == "cannet_bn": self.backbone = _cannet_bn(self.block_size, self.norm, self.act) elif model_name == "vgg11": self.backbone = _vgg_encoder("vgg11", self.block_size, self.norm, self.act) elif model_name == "vgg11_ae": self.backbone = _vgg_encoder_decoder("vgg11", self.block_size, self.norm, self.act) elif model_name == "vgg11_bn": self.backbone = _vgg_encoder("vgg11_bn", self.block_size, self.norm, self.act) elif model_name == "vgg11_bn_ae": self.backbone = _vgg_encoder_decoder("vgg11_bn", self.block_size, self.norm, self.act) elif model_name == "vgg13": self.backbone = _vgg_encoder("vgg13", self.block_size, self.norm, self.act) elif model_name == "vgg13_ae": self.backbone = _vgg_encoder_decoder("vgg13", self.block_size, self.norm, self.act) elif model_name == "vgg13_bn": self.backbone = _vgg_encoder("vgg13_bn", self.block_size, self.norm, self.act) elif model_name == "vgg13_bn_ae": self.backbone = _vgg_encoder_decoder("vgg13_bn", self.block_size, self.norm, self.act) elif model_name == "vgg16": self.backbone = _vgg_encoder("vgg16", self.block_size, self.norm, self.act) elif model_name == "vgg16_ae": self.backbone = _vgg_encoder_decoder("vgg16", self.block_size, self.norm, self.act) elif model_name == "vgg16_bn": self.backbone = _vgg_encoder("vgg16_bn", self.block_size, self.norm, self.act) elif model_name == "vgg16_bn_ae": self.backbone = _vgg_encoder_decoder("vgg16_bn", self.block_size, self.norm, self.act) elif model_name == "vgg19": self.backbone = _vgg_encoder("vgg19", self.block_size, self.norm, self.act) elif model_name == "vgg19_ae": self.backbone = _vgg_encoder_decoder("vgg19", self.block_size, self.norm, self.act) elif model_name == "vgg19_bn": self.backbone = _vgg_encoder("vgg19_bn", self.block_size, self.norm, self.act) elif model_name == "vgg19_bn_ae": self.backbone = _vgg_encoder_decoder("vgg19_bn", self.block_size, self.norm, self.act) elif model_name in supported_vit_backbones: self.backbone = _vit(model_name, block_size=self.block_size, num_vpt=self.num_vpt, vpt_drop=self.vpt_drop, input_size=self.input_size, norm=self.norm, act=self.act) elif model_name in available_hrnets: self.backbone = _hrnet(model_name, block_size=self.block_size, norm=self.norm, act=self.act) else: self.backbone = _timm_model(model_name, self.block_size, self.norm, self.act) def _build_head(self) -> None: channels = self.backbone.decoder_channels if self.zero_inflated: self.bin_head = conv1x1( in_channels=channels, out_channels=len(self.bins) - 1, ) self.pi_head = conv1x1( in_channels=channels, out_channels=2, ) # this models structural 0s. else: self.bin_head = conv1x1( in_channels=channels, out_channels=len(self.bins), ) def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: x = self.backbone(x) if self.zero_inflated: logit_pi_maps = self.pi_head(x) # shape: (B, 2, H, W) logit_maps = self.bin_head(x) # shape: (B, C, H, W) lambda_maps = (logit_maps.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # shape: (B, 1, H, W) # logit_pi_maps.softmax(dim=1)[:, 0] is the probability of zeros den_maps = logit_pi_maps.softmax(dim=1)[:, 1:] * lambda_maps # expectation of the Poisson distribution if self.training: return logit_pi_maps, logit_maps, lambda_maps, den_maps else: return den_maps else: logit_maps = self.bin_head(x) den_maps = (logit_maps.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True) if self.training: return logit_maps, den_maps else: return den_maps def _ebc( model_name: str, block_size: int, bins: List[Tuple[float, float]], bin_centers: List[float], zero_inflated: bool = True, num_vpt: Optional[int] = None, vpt_drop: Optional[float] = None, input_size: Optional[int] = None, norm: str = "none", act: str = "none" ) -> EBC: return EBC( model_name=model_name, block_size=block_size, bins=bins, bin_centers=bin_centers, zero_inflated=zero_inflated, num_vpt=num_vpt, vpt_drop=vpt_drop, input_size=input_size, norm=norm, act=act )