import logging
import math
from itertools import chain

import torch
import torch.fft
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from torch.utils.checkpoint import checkpoint

from .transformer_ls import AttentionLS

_logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    """Two-layer MLP with activation and dropout, as used in transformer blocks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SpectralGatingNetwork(nn.Module):
    """Mixes tokens with an element-wise, learned complex filter in the 2D
    Fourier domain (GFNet-style spectral gating)."""

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Complex filter stored as (real, imag) pairs; w must equal h // 2 + 1
        # to match the one-sided rfft2 output.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape  # e.g. torch.Size([1, 262144, 1024])
        if spatial_size is None:
            a = b = int(math.sqrt(N))  # e.g. a = b = 512
        else:
            a, b = spatial_size

        x = x.view(B, a, b, C)  # e.g. torch.Size([1, 512, 512, 1024])

        # Run the FFT in float32 for numerical stability (this region used to
        # be wrapped in an autocast to float32), then cast back afterwards.
        dtype = x.dtype
        x = x.to(torch.float32)
        x = torch.fft.rfft2(
            x, dim=(1, 2), norm="ortho"
        )  # e.g. torch.Size([1, 512, 257, 1024])
        weight = torch.view_as_complex(
            self.complex_weight.to(torch.float32)
        )  # e.g. torch.Size([512, 257, 1024])
        x = x * weight
        x = torch.fft.irfft2(
            x, s=(a, b), dim=(1, 2), norm="ortho"
        )  # e.g. torch.Size([1, 512, 512, 1024])
        x = x.to(dtype)

        x = x.reshape(B, N, C)  # e.g. torch.Size([1, 262144, 1024])
        return x


class BlockSpectralGating(nn.Module):
    """Transformer block that replaces self-attention with spectral gating."""

    def __init__(
        self,
        dim,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        h=14,
        w=8,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = SpectralGatingNetwork(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, *args):
        # Extra positional args (e.g. the conditioning tensor that SpectFormer
        # passes to every block) are accepted and ignored.
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x
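# A minimal shape sanity-check for the spectral path (a sketch with
# illustrative sizes, not a configuration shipped with this module). Note that
# w must equal h // 2 + 1 so the filter matches the one-sided rfft2 output:
#
#     blk = BlockSpectralGating(dim=64, h=16, w=9)
#     x = torch.randn(2, 16 * 16, 64)  # (B, N, C) tokens on a 16 x 16 grid
#     y = blk(x)
#     assert y.shape == x.shape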
class BlockAttention(nn.Module):
    """Transformer block built around long/short (local window + low-rank
    global) attention, optionally conditioned via adaptive layer norm (adaLN).
    """

    def __init__(
        self,
        dim,
        num_heads: int = 8,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        w=2,
        dp_rank=2,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        rpe=False,
        adaLN=False,
        nglo=0,
    ):
        """
        Args:
            num_heads: Number of attention heads; 4 for tiny, 8 for small, and
                12 for base.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.attn = AttentionLS(
            dim=dim,
            num_heads=num_heads,
            w=w,
            dp_rank=dp_rank,
            nglo=nglo,
            rpe=rpe,
        )
        if adaLN:
            # Maps the conditioning tensor to per-token shift/scale/gate
            # triplets for the attention and MLP branches.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(dim, dim, bias=True),
                act_layer(),
                nn.Linear(dim, 6 * dim, bias=True),
            )
        else:
            self.adaLN_modulation = None

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        if self.adaLN_modulation is not None:
            (
                shift_mha,
                scale_mha,
                gate_mha,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.adaLN_modulation(c).chunk(6, dim=2)
        else:
            # Identity modulation: unit scales and gates, zero shifts.
            scale_mha = gate_mha = scale_mlp = gate_mlp = 1.0
            shift_mha = shift_mlp = 0.0
        x = x + gate_mha * self.drop_path(
            self.attn(self.norm1(x) * scale_mha + shift_mha)
        )
        x = x + gate_mlp * self.drop_path(
            self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
        )
        return x
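# Sketch of the conditioned call (illustrative sizes; whether a given token
# count works also depends on the AttentionLS window settings):
#
#     blk = BlockAttention(dim=64, num_heads=4, adaLN=True)
#     x = torch.randn(2, 256, 64)  # (B, N, C) tokens
#     c = torch.randn(2, 256, 64)  # conditioning tensor (noise in SpectFormer)
#     y = blk(x, c)
#     assert y.shape == x.shape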
class SpectFormer(nn.Module):
    def __init__(
        self,
        grid_size: int = 224 // 16,
        embed_dim=768,
        depth=12,
        n_spectral_blocks=4,
        num_heads: int = 8,
        mlp_ratio=4.0,
        uniform_drop=False,
        drop_rate=0.0,
        drop_path_rate=0.0,
        window_size=2,
        dp_rank=2,
        norm_layer=nn.LayerNorm,
        checkpoint_layers: list[int] | None = None,
        rpe=False,
        ensemble: int | None = None,
        nglo: int = 0,
    ):
        """
        Args:
            grid_size (int): spatial token grid size, i.e. img_size // patch_size
            embed_dim (int): embedding dimension
            depth (int): total number of blocks (spectral plus attention)
            n_spectral_blocks (int): number of spectral gating blocks
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            uniform_drop (bool): True for uniform, False for linearly
                increasing drop path probability.
            drop_rate (float): dropout rate
            drop_path_rate (float): drop path (stochastic depth) rate
            window_size: window size for long/short attention
            dp_rank: dynamic projection rank for long/short attention
            norm_layer (nn.Module): normalization layer for attention blocks
            checkpoint_layers: indices of the layers to wrap in activation
                checkpointing
            rpe: use relative position encoding in long/short attention blocks.
            ensemble: integer indicating the ensemble size, or None for a
                deterministic model.
            nglo: number of (additional) global tokens.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_spectral_blocks = n_spectral_blocks
        self._checkpoint_layers = checkpoint_layers or []
        self.ensemble = ensemble
        self.nglo = nglo

        # Spectral filter dimensions; w is the one-sided rfft2 width.
        h = grid_size
        w = h // 2 + 1

        if uniform_drop:
            _logger.info(
                f"Using uniform droppath with expected rate {drop_path_rate}."
            )
            dpr = [drop_path_rate for _ in range(depth)]
        else:
            _logger.info(
                f"Using linear droppath with expected rate {drop_path_rate * 0.5}."
            )
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks_spectral_gating = nn.ModuleList()
        self.blocks_attention = nn.ModuleList()
        for i in range(depth):
            if i < n_spectral_blocks:
                layer = BlockSpectralGating(
                    dim=embed_dim,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    h=h,
                    w=w,
                )
                self.blocks_spectral_gating.append(layer)
            else:
                layer = BlockAttention(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    w=window_size,
                    dp_rank=dp_rank,
                    rpe=rpe,
                    adaLN=ensemble is not None,
                    nglo=nglo,
                )
                self.blocks_attention.append(layer)

        self.apply(self._init_weights)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tokens: Tensor of shape (B, N, C) for a deterministic forecast or
                (B * E, N, C) for an ensemble forecast.

        Returns:
            Tensor of the same shape as the input.
        """
        if self.ensemble is not None:
            BE, N, C = tokens.shape
            noise = torch.randn(
                size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
            )
        else:
            noise = None

        for i, blk in enumerate(
            chain(self.blocks_spectral_gating, self.blocks_attention)
        ):
            if i in self._checkpoint_layers:
                tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
            else:
                tokens = blk(tokens, noise)
        return tokens

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
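# A minimal end-to-end construction sketch (hypothetical settings, not
# defaults of this module). With `ensemble` set, every attention block is
# conditioned on fresh Gaussian noise through adaLN, so repeated calls on the
# same input yield different ensemble members:
#
#     model = SpectFormer(
#         grid_size=16,
#         embed_dim=256,
#         depth=8,
#         n_spectral_blocks=2,
#         num_heads=8,
#         ensemble=4,
#     )
#     tokens = torch.randn(2 * 4, 16 * 16, 256)  # (B * E, N, C)
#     out = model(tokens)
#     assert out.shape == tokens.shape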