import math
import logging
from itertools import chain

import torch
import torch.fft
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, trunc_normal_

from .transformer_ls import AttentionLS

_logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

    def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x


class SpectralGatingNetwork(nn.Module):
    """Learned complex-valued filter applied elementwise in the 2D Fourier domain."""
def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Real and imaginary parts of the filter, stored as a trailing dim of 2
        # so the weight can be reinterpreted via torch.view_as_complex.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape  # e.g. torch.Size([1, 262144, 1024])
        if spatial_size is None:
            a = b = int(math.sqrt(N))  # square grid assumed, e.g. a = b = 512
        else:
            a, b = spatial_size
        x = x.view(B, a, b, C)  # e.g. torch.Size([1, 512, 512, 1024])
        # This block previously ran under autocast to float32; the cast is now explicit.
        dtype = x.dtype
        x = x.to(torch.float32)
        x = torch.fft.rfft2(
            x, dim=(1, 2), norm="ortho"
        )  # e.g. torch.Size([1, 512, 257, 1024])
        weight = torch.view_as_complex(
            self.complex_weight.to(torch.float32)
        )  # e.g. torch.Size([512, 257, 1024])
        x = x * weight
        x = torch.fft.irfft2(
            x, s=(a, b), dim=(1, 2), norm="ortho"
        )  # e.g. torch.Size([1, 512, 512, 1024])
        x = x.to(dtype)  # restore the original dtype
        x = x.reshape(B, N, C)  # e.g. torch.Size([1, 262144, 1024])
        return x
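

# A minimal usage sketch (not part of the original module): demonstrates the
# shape round-trip through SpectralGatingNetwork. All sizes here are
# illustrative assumptions, not values used by the original code.
def _demo_spectral_gating():
    layer = SpectralGatingNetwork(dim=16, h=8, w=5)  # w = h // 2 + 1 for rfft2
    tokens = torch.randn(2, 64, 16)  # (B, N, C) with N = 8 * 8 a perfect square
    out = layer(tokens)
    assert out.shape == tokens.shape  # spectral filtering preserves (B, N, C)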


class BlockSpectralGating(nn.Module):
    """Transformer block with spectral gating in place of attention; a single
    residual connection spans the norm -> filter -> norm -> MLP pipeline."""
def __init__(
self,
dim,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
h=14,
w=8,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = SpectralGatingNetwork(dim, h=h, w=w)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)

    def forward(self, x, *args):
        # *args absorbs the conditioning tensor that attention blocks receive;
        # it is unused by spectral gating blocks.
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x


class BlockAttention(nn.Module):
def __init__(
self,
dim,
num_heads: int = 8,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
w=2,
dp_rank=2,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
rpe=False,
adaLN=False,
nglo=0,
):
"""
num_heads: Attention heads. 4 for tiny, 8 for small and 12 for base
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.attn = AttentionLS(
dim=dim,
num_heads=num_heads,
w=w,
dp_rank=dp_rank,
nglo=nglo,
rpe=rpe,
)
if adaLN:
self.adaLN_modulation = nn.Sequential(
nn.Linear(dim, dim, bias=True),
act_layer(),
nn.Linear(dim, 6 * dim, bias=True),
)
else:
self.adaLN_modulation = None

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        if self.adaLN_modulation is not None:
            # Project the conditioning tensor into shift/scale/gate triples
            # for the attention (mha) and MLP branches.
(
shift_mha,
scale_mha,
gate_mha,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(c).chunk(6, dim=2)
        else:
            # Unconditioned path: all six modulation factors default to 1.0.
            shift_mha, scale_mha, gate_mha, shift_mlp, scale_mlp, gate_mlp = 6 * (1.0,)
x = x + gate_mha * self.drop_path(
self.attn(
self.norm1(x) * scale_mha + shift_mha,
)
)
x = x + gate_mlp * self.drop_path(
self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
)
return x
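

# A minimal sketch (not part of the original module): illustrates how the adaLN
# projection output is split into the six modulation tensors used above. The
# dimensions are illustrative assumptions.
def _demo_adaln_chunks():
    dim = 8
    c = torch.randn(2, 16, dim)  # conditioning tensor of shape (B, N, C)
    modulation = nn.Sequential(
        nn.Linear(dim, dim, bias=True),
        nn.GELU(),
        nn.Linear(dim, 6 * dim, bias=True),
    )
    shift_mha, scale_mha, gate_mha, shift_mlp, scale_mlp, gate_mlp = (
        modulation(c).chunk(6, dim=2)
    )
    assert shift_mha.shape == c.shape  # each chunk matches the input shape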


class SpectFormer(nn.Module):
def __init__(
self,
grid_size: int = 224 // 16,
embed_dim=768,
depth=12,
n_spectral_blocks=4,
num_heads: int = 8,
mlp_ratio=4.0,
uniform_drop=False,
drop_rate=0.0,
drop_path_rate=0.0,
window_size=2,
dp_rank=2,
norm_layer=nn.LayerNorm,
checkpoint_layers: list[int] | None = None,
rpe=False,
ensemble: int | None = None,
nglo: int = 0,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
embed_dim (int): embedding dimension
depth (int): depth of transformer
n_spectral_blocks (int): number of spectral gating blocks
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
uniform_drop (bool): true for uniform, false for linearly increasing drop path probability.
drop_rate (float): dropout rate
drop_path_rate (float): drop path (stochastic depth) rate
window_size: window size for long/short attention
dp_rank: dp rank for long/short attention
norm_layer: (nn.Module): normalization layer for attention blocks
checkpoint_layers: indicate which layers to use for checkpointing
rpe: Use relative position encoding in Long-Short attention blocks.
ensemble: Integer indicating ensemble size or None for deterministic model.
nglo: Number of (additional) global tokens.
"""
super().__init__()
self.embed_dim = embed_dim
self.n_spectral_blocks = n_spectral_blocks
self._checkpoint_layers = checkpoint_layers or []
self.ensemble = ensemble
self.nglo = nglo
        h = grid_size
        w = h // 2 + 1  # width of the rfft2 output along the second spatial dim
if uniform_drop:
_logger.info(f"Using uniform droppath with expect rate {drop_path_rate}.")
dpr = [drop_path_rate for _ in range(depth)]
else:
            _logger.info(
                f"Using linear drop path with expected rate {drop_path_rate * 0.5}."
            )
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks_spectral_gating = nn.ModuleList()
self.blocks_attention = nn.ModuleList()
for i in range(depth):
if i < n_spectral_blocks:
layer = BlockSpectralGating(
dim=embed_dim,
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
h=h,
w=w,
)
self.blocks_spectral_gating.append(layer)
else:
layer = BlockAttention(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
w=window_size,
dp_rank=dp_rank,
rpe=rpe,
                    adaLN=ensemble is not None,
nglo=nglo,
)
self.blocks_attention.append(layer)
self.apply(self._init_weights)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tokens: Tensor of shape (B, N, C) for a deterministic model or
                (B * E, N, C) for an ensemble forecast.
        Returns:
            Tensor of the same shape as the input.
        """
        if self.ensemble:
            # Sample per-token noise to serve as the adaLN conditioning tensor.
            BE, N, C = tokens.shape
            noise = torch.randn(
                size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
            )
        else:
            noise = None
for i, blk in enumerate(
chain(self.blocks_spectral_gating, self.blocks_attention)
):
if i in self._checkpoint_layers:
tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
else:
tokens = blk(tokens, noise)
return tokens

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
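

# A minimal smoke test (not part of the original module): builds a small
# SpectFormer and runs one forward pass. All sizes are illustrative
# assumptions, and the attention blocks rely on the companion transformer_ls
# module; call this after importing the module as part of its package, since
# the relative import of AttentionLS prevents running this file directly.
def _smoke_test():
    model = SpectFormer(
        grid_size=8,  # token grid of 8 x 8 = 64 tokens
        embed_dim=64,
        depth=4,
        n_spectral_blocks=2,
        num_heads=4,
    )
    tokens = torch.randn(2, 64, 64)  # (B, N, C) with N = grid_size ** 2
    out = model(tokens)
    assert out.shape == tokens.shape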