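"""SpectFormer backbone combining spectral gating blocks and long/short
attention blocks over a flattened token grid."""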
import math
import logging
from itertools import chain

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, trunc_normal_
import torch.fft

from .transformer_ls import AttentionLS

_logger = logging.getLogger(__name__)


class Mlp(nn.Module):
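    """Two-layer feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout."""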
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SpectralGatingNetwork(nn.Module):
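    """Global filter layer: mixes tokens by multiplying the 2D FFT of the token
    grid with a learned complex-valued weight, then transforming back.

    Args:
        dim: channel dimension of the tokens.
        h: height of the token grid.
        w: width of the half-spectrum kept by rfft2, i.e. h // 2 + 1.
    """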
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape  # torch.Size([1, 262144, 1024])
        if spatial_size is None:
            a = b = int(math.sqrt(N))  # a = b = 512
        else:
            a, b = spatial_size
        x = x.view(B, a, b, C)  # torch.Size([1, 512, 512, 1024])
        # The FFT runs in float32 (this region used to be wrapped in autocast);
        # the result is cast back to the input dtype afterwards.
        dtype = x.dtype
        x = x.to(torch.float32)
        x = torch.fft.rfft2(
            x, dim=(1, 2), norm="ortho"
        )  # torch.Size([1, 512, 257, 1024])
        weight = torch.view_as_complex(
            self.complex_weight.to(torch.float32)
        )  # torch.Size([512, 257, 1024])
        x = x * weight
        x = torch.fft.irfft2(
            x, s=(a, b), dim=(1, 2), norm="ortho"
        )  # torch.Size([1, 512, 512, 1024])
        x = x.to(dtype)
        x = x.reshape(B, N, C)  # torch.Size([1, 262144, 1024])
        return x


class BlockSpectralGating(nn.Module):
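    """Attention-free transformer block: spectral gating followed by an MLP,
    wrapped in a single residual connection with DropPath."""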
    def __init__(
        self,
        dim,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        h=14,
        w=8,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = SpectralGatingNetwork(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, *args):
        # *args absorbs the conditioning tensor that attention blocks receive;
        # it is unused here.
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x


class BlockAttention(nn.Module):
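    """Transformer block with long/short attention (AttentionLS) and an MLP,
    each in its own residual branch. With adaLN enabled, a conditioning tensor
    produces per-branch shift, scale, and gate factors."""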
    def __init__(
        self,
        dim,
        num_heads: int = 8,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        w=2,
        dp_rank=2,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        rpe=False,
        adaLN=False,
        nglo=0,
    ):
        """
        num_heads: Number of attention heads: 4 for tiny, 8 for small, and 12 for base.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.attn = AttentionLS(
            dim=dim,
            num_heads=num_heads,
            w=w,
            dp_rank=dp_rank,
            nglo=nglo,
            rpe=rpe,
        )
        if adaLN:
            # Conditioning MLP that produces shift/scale/gate for both branches.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(dim, dim, bias=True),
                act_layer(),
                nn.Linear(dim, 6 * dim, bias=True),
            )
        else:
            self.adaLN_modulation = None

    def forward(self, x: torch.Tensor, c: torch.Tensor | None) -> torch.Tensor:
        if self.adaLN_modulation is not None:
            # Project the conditioning tensor into per-branch shift, scale, and gate.
            (
                shift_mha,
                scale_mha,
                gate_mha,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            ) = self.adaLN_modulation(c).chunk(6, dim=2)
        else:
            shift_mha, scale_mha, gate_mha, shift_mlp, scale_mlp, gate_mlp = 6 * (1.0,)
        x = x + gate_mha * self.drop_path(
            self.attn(
                self.norm1(x) * scale_mha + shift_mha,
            )
        )
        x = x + gate_mlp * self.drop_path(
            self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
        )
        return x


class SpectFormer(nn.Module):
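    """SpectFormer-style backbone: the first n_spectral_blocks layers mix tokens
    with spectral gating, the remaining layers with long/short attention. For
    ensemble forecasts, per-member noise conditions adaLN modulation in the
    attention blocks."""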
    def __init__(
        self,
        grid_size: int = 224 // 16,
        embed_dim=768,
        depth=12,
        n_spectral_blocks=4,
        num_heads: int = 8,
        mlp_ratio=4.0,
        uniform_drop=False,
        drop_rate=0.0,
        drop_path_rate=0.0,
        window_size=2,
        dp_rank=2,
        norm_layer=nn.LayerNorm,
        checkpoint_layers: list[int] | None = None,
        rpe=False,
        ensemble: int | None = None,
        nglo: int = 0,
    ):
""" | |
Args: | |
img_size (int, tuple): input image size | |
patch_size (int, tuple): patch size | |
embed_dim (int): embedding dimension | |
depth (int): depth of transformer | |
n_spectral_blocks (int): number of spectral gating blocks | |
mlp_ratio (int): ratio of mlp hidden dim to embedding dim | |
uniform_drop (bool): true for uniform, false for linearly increasing drop path probability. | |
drop_rate (float): dropout rate | |
drop_path_rate (float): drop path (stochastic depth) rate | |
window_size: window size for long/short attention | |
dp_rank: dp rank for long/short attention | |
norm_layer: (nn.Module): normalization layer for attention blocks | |
checkpoint_layers: indicate which layers to use for checkpointing | |
rpe: Use relative position encoding in Long-Short attention blocks. | |
ensemble: Integer indicating ensemble size or None for deterministic model. | |
nglo: Number of (additional) global tokens. | |
""" | |
        super().__init__()
        self.embed_dim = embed_dim
        self.n_spectral_blocks = n_spectral_blocks
        self._checkpoint_layers = checkpoint_layers or []
        self.ensemble = ensemble
        self.nglo = nglo
        h = grid_size
        w = h // 2 + 1  # rfft2 keeps h // 2 + 1 frequencies along the last transformed axis
        if uniform_drop:
            _logger.info(f"Using uniform droppath with expected rate {drop_path_rate}.")
            dpr = [drop_path_rate for _ in range(depth)]
        else:
            _logger.info(
                f"Using linear droppath with expected rate {drop_path_rate * 0.5}."
            )
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks_spectral_gating = nn.ModuleList()
        self.blocks_attention = nn.ModuleList()
        # The first n_spectral_blocks layers are spectral gating blocks; the
        # remaining depth - n_spectral_blocks layers are attention blocks.
        for i in range(depth):
            if i < n_spectral_blocks:
                layer = BlockSpectralGating(
                    dim=embed_dim,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    h=h,
                    w=w,
                )
                self.blocks_spectral_gating.append(layer)
            else:
                layer = BlockAttention(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    w=window_size,
                    dp_rank=dp_rank,
                    rpe=rpe,
                    adaLN=ensemble is not None,
                    nglo=nglo,
                )
                self.blocks_attention.append(layer)
        self.apply(self._init_weights)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tokens: Tensor of shape (B, N, C) for a deterministic forecast or
                (B*E, N, C) for an ensemble forecast.

        Returns:
            Tensor of the same shape as the input.
        """
        if self.ensemble:
            # Per-member Gaussian noise used as the adaLN conditioning tensor.
            BE, N, C = tokens.shape
            noise = torch.randn(
                size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
            )
        else:
            noise = None
        for i, blk in enumerate(
            chain(self.blocks_spectral_gating, self.blocks_attention)
        ):
            if i in self._checkpoint_layers:
                tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
            else:
                tokens = blk(tokens, noise)
        return tokens

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
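

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes the module is imported as part of its
    # package so the relative `transformer_ls` import resolves. The sizes below
    # are illustrative (a 14x14 token grid with a small embedding dimension),
    # not the configuration used in practice.
    grid = 14
    model = SpectFormer(
        grid_size=grid,
        embed_dim=64,
        depth=4,
        n_spectral_blocks=2,
        num_heads=4,
        ensemble=None,  # deterministic: attention blocks skip adaLN modulation
    )
    tokens = torch.randn(2, grid * grid, 64)  # (B, N, C)
    out = model(tokens)
    assert out.shape == tokens.shape
    print(out.shape)  # torch.Size([2, 196, 64])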