# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class Matrix_Decomposition_2D_Base(nn.Module):
    """Base class of 2D Matrix Decomposition.

    Args:
        MD_S (int): The number of spatial coefficients in
            Matrix Decomposition; it may be used to compute
            the number of latent dimensions D in Matrix
            Decomposition. Defaults to 1.
        MD_R (int): The number of latent dimensions R in
            Matrix Decomposition. Defaults to 64.
        train_steps (int): The number of iteration steps of the
            Multiplicative Update (MU) rule used to solve Non-negative
            Matrix Factorization (NMF) in training. Defaults to 6.
        eval_steps (int): The number of iteration steps of the
            Multiplicative Update (MU) rule used to solve Non-negative
            Matrix Factorization (NMF) in evaluation. Defaults to 7.
        inv_t (int): Inverse temperature applied to the coefficients
            before the softmax; larger values sharpen the softmax.
            Defaults to 100.
        rand_init (bool): Whether to initialize the bases randomly.
            Defaults to True.
    """

    def __init__(self,
                 MD_S=1,
                 MD_R=64,
                 train_steps=6,
                 eval_steps=7,
                 inv_t=100,
                 rand_init=True):
        super().__init__()

        self.S = MD_S
        self.R = MD_R

        self.train_steps = train_steps
        self.eval_steps = eval_steps

        self.inv_t = inv_t
        self.rand_init = rand_init

    def _build_bases(self, B, S, D, R, device=None):
        raise NotImplementedError

    def local_step(self, x, bases, coef):
        raise NotImplementedError

    def local_inference(self, x, bases):
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        raise NotImplementedError

    def forward(self, x, return_bases=False):
        """Forward Function."""
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        D = C // self.S
        N = H * W
        x = x.view(B * self.S, D, N)
        if not self.rand_init and not hasattr(self, 'bases'):
            bases = self._build_bases(1, self.S, D, self.R, device=x.device)
            self.register_buffer('bases', bases)

        # (S, D, R) -> (B * S, D, R)
        if self.rand_init:
            bases = self._build_bases(B, self.S, D, self.R, device=x.device)
        else:
            bases = self.bases.repeat(B, 1, 1)

        bases, coef = self.local_inference(x, bases)

        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)

        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))

        # (B * S, D, N) -> (B, C, H, W)
        x = x.view(B, C, H, W)

        return x


class NMF2D(Matrix_Decomposition_2D_Base):
    """Non-negative Matrix Factorization (NMF) module.

    It is inherited from ``Matrix_Decomposition_2D_Base`` module.
    """

    def __init__(self, args=dict()):
        super().__init__(**args)
        self.inv_t = 1

    def _build_bases(self, B, S, D, R, device=None):
        """Build bases in initialization."""
        if device is None:
            device = get_device()
        bases = torch.rand((B * S, D, R)).to(device)
        bases = F.normalize(bases, dim=1)

        return bases

    def local_step(self, x, bases, coef):
        """Local step in iteration to renew bases and coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Compute coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)

        return coef
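

# A minimal sketch of how ``NMF2D`` can be exercised on its own; the tensor
# sizes and the ``MD_R=16`` value are illustrative choices, not values taken
# from any SegNeXt config.
def _demo_nmf2d():
    ham = NMF2D(dict(MD_R=16))
    # torch.rand gives non-negative inputs, as the NMF MU rule requires;
    # (B, C, H, W) with C divisible by MD_S (here MD_S defaults to 1).
    x = torch.rand(2, 32, 8, 8)
    out = ham(x)
    # NMF reconstructs the feature map from R bases, so the shape is kept.
    assert out.shape == x.shape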


class Hamburger(nn.Module):
    """Hamburger Module. It consists of one slice of "ham" (matrix
    decomposition) and two slices of "bread" (linear transformation).

    Args:
        ham_channels (int): Input and output channels of the feature.
        ham_kwargs (dict): Config of the matrix decomposition module.
        norm_cfg (dict | None): Config of norm layers.
    """

    def __init__(self,
                 ham_channels=512,
                 ham_kwargs=dict(),
                 norm_cfg=None,
                 **kwargs):
        super().__init__()

        self.ham_in = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)

        self.ham = NMF2D(ham_kwargs)

        self.ham_out = ConvModule(
            ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=True)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        ham = F.relu(x + enjoy, inplace=True)

        return ham
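

# A minimal sketch of the "bread-ham-bread" structure end to end;
# ``ham_channels=64`` and the input sizes are illustrative assumptions.
def _demo_hamburger():
    burger = Hamburger(ham_channels=64, ham_kwargs=dict(MD_R=16))
    x = torch.rand(2, 64, 8, 8)
    # 1x1 conv -> ReLU -> NMF2D -> 1x1 conv, then a residual ReLU.
    out = burger(x)
    assert out.shape == x.shape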


@MODELS.register_module()
class LightHamHead(BaseDecodeHead):
    """SegNeXt decode head.

    This decode head is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic Segmentation
    <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Specifically, LightHamHead is inspired by HamNet from
    `Is Attention Better Than Matrix Decomposition?
    <https://arxiv.org/abs/2109.04553>`_.

    Args:
        ham_channels (int): Input channels for Hamburger.
            Defaults to 512.
        ham_kwargs (dict): kwargs for Ham. Defaults to dict().
    """

    def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.ham_channels = ham_channels

        self.squeeze = ConvModule(
            sum(self.in_channels),
            self.ham_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)

        self.align = ConvModule(
            self.ham_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        inputs = [
            resize(
                level,
                size=inputs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners) for level in inputs
        ]

        inputs = torch.cat(inputs, dim=1)
        # apply a conv block to squeeze feature map
        x = self.squeeze(inputs)
        # apply hamburger module
        x = self.hamburger(x)
        # apply a conv block to align feature map
        output = self.align(x)
        output = self.cls_seg(output)
        return output
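

# A minimal usage sketch for the whole head; the backbone channel counts,
# ``num_classes`` and other settings below are hypothetical stand-ins chosen
# only to make the example self-contained, not values from a real config.
def _demo_light_ham_head():
    head = LightHamHead(
        in_channels=[16, 32, 64],  # channels of the selected feature levels
        in_index=[0, 1, 2],
        channels=32,
        ham_channels=64,
        dropout_ratio=0.1,
        num_classes=19,
        ham_kwargs=dict(MD_R=16))
    feats = [
        torch.rand(2, 16, 32, 32),
        torch.rand(2, 32, 16, 16),
        torch.rand(2, 64, 8, 8),
    ]
    # All levels are resized to the first level's resolution, concatenated,
    # squeezed to ham_channels, refined by Hamburger, then classified.
    seg_logits = head(feats)
    assert seg_logits.shape == (2, 19, 32, 32)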