Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class Channel_attention(nn.Module):
    """Reweight each channel of the input by a learned sigmoid gate.

    A global average-pooled and a global max-pooled descriptor of the input
    are each passed through the same bottleneck MLP; the two MLP outputs are
    summed element-wise, squashed with a sigmoid, and multiplied back onto the
    input channel-wise (the channel-attention branch of CBAM).

    Args:
        ch: Number of input channels. Must match the channel dimension of the
            tensor given to ``forward`` (the MLP's first Linear expects it).
        ratio: Reduction factor for the MLP's hidden layer. The hidden width
            is ``ch // ratio``, clamped to at least 1 so that ``ch < ratio``
            still produces a trainable gate.
    """

    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Clamp to 1: when ch < ratio, ch // ratio == 0 would build zero-width
        # Linear layers whose output is always 0, turning the gate into a
        # constant sigmoid(0) == 0.5 that can never be trained.
        hidden = max(ch // ratio, 1)
        # A single MLP whose weights are shared by both pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, ch, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate.

        Args:
            x: Tensor of shape (N, ch, H, W), or unbatched (ch, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pool to (..., ch, 1, 1), then drop the two trailing singleton dims so
        # the Linear layers see ``ch`` as the last dimension. squeeze(-1) is
        # used (rather than flatten(1)) so unbatched input also works.
        avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # The two branches are combined by element-wise sum, not concatenation.
        gate = self.sigmoid(avg_desc + max_desc)
        # Re-add the spatial dims so the per-channel gate broadcasts over H, W.
        return x * gate.unsqueeze(-1).unsqueeze(-1)