import torch
from einops import rearrange
from torch import nn
import numpy as np
from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
from .embedding import (
LinearEmbedding,
PatchEmbed3D,
PerceiverChannelEmbedding,
LinearDecoder,
PerceiverDecoder,
)
from .flow import HelioFlowModel
class HelioSpectFormer(nn.Module):
"""
A note on the ensemble capability:
    Ensembles of size E are generated by setting `ensemble=E`. In this case, the forward
    pass generates ensemble members after tokenization by expanding the batch dimension
    from B to B x E. Noise is injected in the `self.backbone` SpectFormer blocks. After the
    backbone, ensemble members ride along implicitly in the batch dimension, mainly
    through the `self.unembed` pass. An explicit ensemble dimension is only created at the
    end of the forward pass.
"""
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_dim: int,
time_embedding: dict,
depth: int,
n_spectral_blocks: int,
num_heads: int,
mlp_ratio: float,
drop_rate: float,
window_size: int,
dp_rank: int,
learned_flow: bool = False,
use_latitude_in_learned_flow: bool = False,
init_weights: bool = False,
checkpoint_layers: list[int] | None = None,
rpe: bool = False,
ensemble: int | None = None,
finetune: bool = True,
nglo: int = 0,
dtype: torch.dtype | None = None,
) -> None:
"""
Args:
img_size: input image size
patch_size: patch size
            in_chans: number of input channels
            embed_dim: embedding dimension
time_embedding: dictionary to configure temporal embedding:
                `type` (str, required): indicates the embedding type, one of `linear` or `perceiver`.
`time_dim` (int): indicates length of time dimension. required for linear embedding.
`n_queries` (int): indicates number of perceiver queries. required for perceiver.
depth: number of transformer blocks
n_spectral_blocks: number of spectral gating blocks
num_heads: Number of transformer heads
mlp_ratio: MLP ratio for transformer blocks
drop_rate: dropout rate
window_size: window size for long/short attention
dp_rank: dp rank for long/short attention
learned_flow: if true, combine learned flow model with spectformer
use_latitude_in_learned_flow: use latitudes in learned flow
init_weights: use optimized weight initialization
checkpoint_layers: indicate which layers to use for checkpointing
rpe: Use relative position encoding in Long-Short attention blocks.
ensemble: Integer indicating ensemble size or None for deterministic model.
            finetune: Indicates whether to train from scratch or fine-tune the model. If set to `True`, the final output layers are removed.
nglo: Number of (additional) global tokens.
dtype: A torch data type. Not used and added only for compatibility with the remainder of the codebase.
"""
super().__init__()
self.learned_flow = learned_flow
self.patch_size = patch_size
self.embed_dim = embed_dim
self.in_chans = in_chans
self.time_embedding = time_embedding
self.ensemble = ensemble
self.finetune = finetune
self.nglo = nglo
if learned_flow:
self.learned_flow_model = HelioFlowModel(
img_size=(img_size, img_size),
use_latitude_in_learned_flow=use_latitude_in_learned_flow,
)
match time_embedding["type"]:
case "linear":
self.time_dim = time_embedding["time_dim"]
if learned_flow:
self.time_dim += 1
self.embedding = LinearEmbedding(
img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate
)
if not self.finetune:
self.unembed = LinearDecoder(
patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim
)
case "perceiver":
self.embedding = PerceiverChannelEmbedding(
in_chans=in_chans,
img_size=img_size,
patch_size=patch_size,
time_dim=time_embedding["time_dim"],
num_queries=time_embedding["n_queries"],
embed_dim=embed_dim,
drop_rate=drop_rate,
)
if not self.finetune:
self.unembed = PerceiverDecoder(
embed_dim=embed_dim,
patch_size=patch_size,
out_chans=in_chans,
)
case _:
raise NotImplementedError(
f'Embedding {time_embedding["type"]} has not been implemented.'
)
if isinstance(depth, list):
            raise NotImplementedError(
                "Multi-scale models are no longer supported. Depth should be a single integer."
            )
self.backbone = SpectFormer(
grid_size=img_size // patch_size,
embed_dim=embed_dim,
depth=depth,
n_spectral_blocks=n_spectral_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
window_size=window_size,
dp_rank=dp_rank,
checkpoint_layers=checkpoint_layers,
rpe=rpe,
ensemble=ensemble,
nglo=nglo,
)
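        # The backbone operates on an (img_size // patch_size) x (img_size // patch_size)
        # token grid, plus `nglo` global tokens if any; per the class docstring, ensemble
        # noise is injected inside these SpectFormer blocks.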
if init_weights:
self.apply(self._init_weights)
# @staticmethod
# def _checkpoint_wrapper(
# model: nn.Module, data: tuple[Tensor, Tensor | None]
# ) -> Tensor:
# return checkpoint(model, data, use_reentrant=False)
def _init_weights(self, module):
if self.time_embedding["type"] == "linear":
            # sampling_step**2 * embed_dim = patch_size**2 * in_chans * time_dim
sampling_step = int(
np.sqrt(
(self.patch_size**2 * self.in_chans * self.time_dim)
/ self.embed_dim
)
)
else:
sampling_step = int(
np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim)
)
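        # Worked example (illustrative values, not defaults): patch_size=8, in_chans=4,
        # time_dim=2, embed_dim=128 -> sampling_step = int(np.sqrt(8**2 * 4 * 2 / 128)) = 2.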
if isinstance(module, PatchEmbed3D):
torch.nn.init.zeros_(module.proj.weight)
c_out = 0
w_pool = 1.0 / sampling_step
for k in range(self.in_chans * self.time_dim):
for i in range(0, self.patch_size, sampling_step):
for j in range(0, self.patch_size, sampling_step):
module.proj.weight.data[
c_out, k, i : i + sampling_step, j : j + sampling_step
] = w_pool
c_out += 1
if module.proj.bias is not None:
module.proj.bias.data.zero_()
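        # The block initializations below appear intended to start the backbone close to an
        # identity map: spectral-gating MLPs get identity weight matrices, while attention
        # MLPs and projections are zeroed (so, assuming residual connections in those blocks,
        # they initially contribute nothing beyond the skip path).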
if isinstance(module, BlockSpectralGating):
for m in [
module.mlp.fc1,
module.mlp.fc2,
]:
# m.weight.data.normal_(mean=0.0, std=0.01)
# torch.nn.init.eye_(m.weight)
torch.nn.init.eye_(m.weight)
if m.bias is not None:
m.bias.data.zero_()
if isinstance(module, BlockAttention):
for m in [
module.mlp.fc1,
module.mlp.fc2,
]:
# torch.nn.init.eye_(m.weight)
torch.nn.init.zeros_(m.weight)
if m.bias is not None:
m.bias.data.zero_()
for m in [
module.attn.qkv,
module.attn.proj,
module.attn.to_dynamic_projection,
]:
# m.weight.data.normal_(mean=0.0, std=0.01)
# torch.nn.init.eye_(m.weight)
torch.nn.init.zeros_(m.weight)
if m.bias is not None:
m.bias.data.zero_()
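        # The Sequential + PixelShuffle case below presumably targets the decoder head
        # (LinearDecoder / PerceiverDecoder): its first layer's weight is zeroed and selected
        # entries are set to 1.0, mirroring the pooled patch-embedding initialization above.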
if isinstance(module, torch.nn.Sequential):
if isinstance(module[1], torch.nn.PixelShuffle):
# torch.nn.init.eye_(module[0].weight.data[:,:,0,0])
torch.nn.init.zeros_(module[0].weight)
if self.time_embedding["type"] == "linear":
c_out = 0
for k in range(1, self.in_chans + 1):
for i in range(
self.patch_size**2 // (self.patch_size * sampling_step)
):
for j in range(self.patch_size):
module[0].weight.data[
c_out : c_out + sampling_step,
j + (k * self.time_dim - 1) * self.patch_size,
] = 1.0
c_out += sampling_step
else:
c_out = 0
for k in range(2 * self.in_chans):
# l = 0
for l_feat in range(self.backbone.embed_dim):
module[0].weight.data[c_out, l_feat] = 1.0
c_out += 1
if module[0].bias is not None:
module[0].bias.data.zero_()
def forward(self, batch):
"""
Args:
            batch: Dictionary containing keys `ts` and `time_delta_input` (plus
                `lead_time_delta` when the perceiver embedding is used).
                Their values are tensors with shapes as follows.
                ts: B, C, T, H, W
                time_delta_input: B, T
Returns:
            Tensor of shape (B, C, H, W) for deterministic or (B, E, C, H, W) for ensemble
                forecasts. If `finetune` is `True`, the backbone tokens are returned instead.
"""
x = batch["ts"]
dt = batch["time_delta_input"]
B, C, T, H, W = x.shape
if self.learned_flow:
y_hat_flow = self.learned_flow_model(batch) # B, C, H, W
            if any(
                param.requires_grad for param in self.learned_flow_model.parameters()
            ):
return y_hat_flow
else:
x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2) # B, C, T+1, H, W
if self.time_embedding["type"] == "perceiver":
dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1)
# embed the data
tokens = self.embedding(x, dt)
# copy tokens in case of ensemble forecast
if self.ensemble:
# B L D -> (B E) L D == BE L D
tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0)
# pass the time series through the encoder
tokens = self.backbone(tokens)
if self.finetune:
return tokens
# Unembed the tokens
# BE L D -> BE C H W
forecast_hat = self.unembed(tokens)
assert forecast_hat.shape == (
B * self.ensemble if self.ensemble else B,
C,
H,
W,
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
if self.learned_flow:
assert y_hat_flow.shape == (
B,
C,
H,
W,
), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}."
if self.ensemble:
y_hat_flow = torch.repeat_interleave(
y_hat_flow, repeats=self.ensemble, dim=0
)
assert y_hat_flow.shape == forecast_hat.shape
forecast_hat = forecast_hat + y_hat_flow
assert forecast_hat.shape == (
B * self.ensemble if self.ensemble else B,
C,
H,
W,
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
if self.ensemble:
forecast_hat = rearrange(
forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble
)
return forecast_hat