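"""SpectFormer-style transformer backbone.

Stacks spectral gating blocks (token mixing via a learnable filter in the 2D
Fourier domain) followed by long-short attention blocks, with optional adaLN
modulation for ensemble forecasts.
"""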
import math
import logging
from itertools import chain
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, trunc_normal_
import torch.fft
from .transformer_ls import AttentionLS
_logger = logging.getLogger(__name__)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SpectralGatingNetwork(nn.Module):
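    """Spectral gating: reshape tokens to a 2D grid, apply rfft2, multiply
    elementwise by a learnable complex filter (stored as (real, imag) pairs in
    ``complex_weight``), and transform back with irfft2."""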
def __init__(self, dim, h=14, w=8):
super().__init__()
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
self.w = w
self.h = h
def forward(self, x, spatial_size=None):
B, N, C = x.shape # torch.Size([1, 262144, 1024])
if spatial_size is None:
a = b = int(math.sqrt(N)) # a=b=512
else:
a, b = spatial_size
x = x.view(B, a, b, C) # torch.Size([1, 512, 512, 1024])
        # Run the FFT filtering in float32 for numerical stability (this span
        # previously ran under an autocast to float32), then cast back below.
        dtype = x.dtype
        x = x.to(torch.float32)
x = torch.fft.rfft2(
x, dim=(1, 2), norm="ortho"
) # torch.Size([1, 512, 257, 1024])
weight = torch.view_as_complex(
self.complex_weight.to(torch.float32)
) # torch.Size([512, 257, 1024])
x = x * weight
x = torch.fft.irfft2(
x, s=(a, b), dim=(1, 2), norm="ortho"
) # torch.Size([1, 512, 512, 1024])
        x = x.to(dtype)  # restore the original dtype
        x = x.reshape(B, N, C)  # torch.Size([1, 262144, 1024])
return x
class BlockSpectralGating(nn.Module):
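    """SpectFormer spectral block: norm -> spectral gating -> norm -> MLP, all
    wrapped in a single residual connection with optional DropPath."""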
def __init__(
self,
dim,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
h=14,
w=8,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = SpectralGatingNetwork(dim, h=h, w=w)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x, *args):
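        # Extra positional args (e.g. the conditioning tensor handed to the
        # attention blocks) are accepted but ignored here.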
x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
return x
class BlockAttention(nn.Module):
def __init__(
self,
dim,
num_heads: int = 8,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
w=2,
dp_rank=2,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
rpe=False,
adaLN=False,
nglo=0,
):
"""
num_heads: Attention heads. 4 for tiny, 8 for small and 12 for base
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.attn = AttentionLS(
dim=dim,
num_heads=num_heads,
w=w,
dp_rank=dp_rank,
nglo=nglo,
rpe=rpe,
)
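        # Optional adaLN conditioning (as in DiT): a small MLP maps the
        # conditioning tensor to six per-token modulation parameters, a
        # shift/scale/gate triple each for the attention and MLP branches.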
if adaLN:
self.adaLN_modulation = nn.Sequential(
nn.Linear(dim, dim, bias=True),
act_layer(),
nn.Linear(dim, 6 * dim, bias=True),
)
else:
self.adaLN_modulation = None
    def forward(self, x: torch.Tensor, c: torch.Tensor | None) -> torch.Tensor:
if self.adaLN_modulation is not None:
(
shift_mha,
scale_mha,
gate_mha,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(c).chunk(6, dim=2)
        else:
            # Identity modulation: no shift, unit scale, unit gate.
            shift_mha = shift_mlp = 0.0
            scale_mha = scale_mlp = 1.0
            gate_mha = gate_mlp = 1.0
x = x + gate_mha * self.drop_path(
self.attn(
self.norm1(x) * scale_mha + shift_mha,
)
)
x = x + gate_mlp * self.drop_path(
self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
)
return x
class SpectFormer(nn.Module):
def __init__(
self,
grid_size: int = 224 // 16,
embed_dim=768,
depth=12,
n_spectral_blocks=4,
num_heads: int = 8,
mlp_ratio=4.0,
uniform_drop=False,
drop_rate=0.0,
drop_path_rate=0.0,
window_size=2,
dp_rank=2,
norm_layer=nn.LayerNorm,
checkpoint_layers: list[int] | None = None,
rpe=False,
ensemble: int | None = None,
nglo: int = 0,
):
"""
Args:
            grid_size (int): side length of the square token grid (i.e. img_size // patch_size)
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            n_spectral_blocks (int): number of spectral gating blocks
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
uniform_drop (bool): true for uniform, false for linearly increasing drop path probability.
drop_rate (float): dropout rate
drop_path_rate (float): drop path (stochastic depth) rate
window_size: window size for long/short attention
dp_rank: dp rank for long/short attention
            norm_layer (nn.Module): normalization layer for attention blocks
            checkpoint_layers: indices of the layers to run with gradient checkpointing
rpe: Use relative position encoding in Long-Short attention blocks.
            ensemble: Integer ensemble size, or None for a deterministic model.
nglo: Number of (additional) global tokens.
"""
super().__init__()
self.embed_dim = embed_dim
self.n_spectral_blocks = n_spectral_blocks
self._checkpoint_layers = checkpoint_layers or []
self.ensemble = ensemble
self.nglo = nglo
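        # Spectral filter dimensions: h is the side of the square token grid;
        # w = h // 2 + 1 matches the width of the rfft2 output.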
h = grid_size
w = h // 2 + 1
if uniform_drop:
_logger.info(f"Using uniform droppath with expect rate {drop_path_rate}.")
dpr = [drop_path_rate for _ in range(depth)]
else:
            _logger.info(
                f"Using linear droppath with expected rate {drop_path_rate * 0.5}."
            )
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks_spectral_gating = nn.ModuleList()
self.blocks_attention = nn.ModuleList()
for i in range(depth):
if i < n_spectral_blocks:
layer = BlockSpectralGating(
dim=embed_dim,
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
h=h,
w=w,
)
self.blocks_spectral_gating.append(layer)
else:
layer = BlockAttention(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
w=window_size,
dp_rank=dp_rank,
rpe=rpe,
                    adaLN=ensemble is not None,
nglo=nglo,
)
self.blocks_attention.append(layer)
self.apply(self._init_weights)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Args:
tokens: Tensor of shape B, N, C for deterministic of BxE, N, C for ensemble forecast.
Returns:
Tensor of same shape as input.
"""
if self.ensemble:
BE, N, C = tokens.shape
noise = torch.randn(
size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
)
else:
noise = None
for i, blk in enumerate(
chain(self.blocks_spectral_gating, self.blocks_attention)
):
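            # Gradient checkpointing on selected layers trades recompute for
            # activation memory.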
if i in self._checkpoint_layers:
tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
else:
tokens = blk(tokens, noise)
return tokens
def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
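

# Minimal smoke test for the self-contained blocks above (a hedged sketch, not
# part of the model). Because this file uses a relative import, run it as a
# module from its parent package, e.g. ``python -m <package>.<this_module>``;
# the package and module names are placeholders.
if __name__ == "__main__":
    # Spectral gating on a 16 x 16 token grid: w = 16 // 2 + 1 = 9.
    sgn = SpectralGatingNetwork(dim=64, h=16, w=9)
    x = torch.randn(2, 16 * 16, 64)
    print(sgn(x).shape)  # torch.Size([2, 256, 64])

    # Full spectral block: norm -> gating -> norm -> MLP with a residual.
    block = BlockSpectralGating(dim=64, h=16, w=9)
    print(block(x).shape)  # torch.Size([2, 256, 64])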