import numpy as np
import torch
from einops import rearrange
from torch import nn

from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
from .embedding import (
    LinearEmbedding,
    PatchEmbed3D,
    PerceiverChannelEmbedding,
    LinearDecoder,
    PerceiverDecoder,
)
from .flow import HelioFlowModel


class HelioSpectFormer(nn.Module):
    """
    A note on the ensemble capability:

    Ensembles of size E are generated by setting `ensemble=E`. In this case, the forward
    pass generates ensemble members after tokenization by increasing the batch dimension
    B to B x E. Noise is injected in the `self.backbone` SpectFormer blocks. After the
    backbone, ensemble members ride along implicitly in the batch dimension (mainly
    through the `self.unembed` pass). An explicit ensemble dimension is only generated
    at the end.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_dim: int,
        time_embedding: dict,
        depth: int,
        n_spectral_blocks: int,
        num_heads: int,
        mlp_ratio: float,
        drop_rate: float,
        window_size: int,
        dp_rank: int,
        learned_flow: bool = False,
        use_latitude_in_learned_flow: bool = False,
        init_weights: bool = False,
        checkpoint_layers: list[int] | None = None,
        rpe: bool = False,
        ensemble: int | None = None,
        finetune: bool = True,
        nglo: int = 0,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            img_size: input image size
            patch_size: patch size
            in_chans: number of input channels
            embed_dim: embedding dimension
            time_embedding: dictionary configuring the temporal embedding (see the
                example configurations sketched below):
                    `type` (str, required): embedding type; one of `linear` or `perceiver`.
                    `time_dim` (int): length of the time dimension; required for both the
                        linear and perceiver embeddings.
                    `n_queries` (int): number of perceiver queries; required for the
                        perceiver embedding.
            depth: number of transformer blocks
            n_spectral_blocks: number of spectral gating blocks
            num_heads: number of transformer heads
            mlp_ratio: MLP ratio for transformer blocks
            drop_rate: dropout rate
            window_size: window size for long/short attention
            dp_rank: dp rank for long/short attention
            learned_flow: if true, combine the learned flow model with the spectformer
            use_latitude_in_learned_flow: use latitudes in the learned flow
            init_weights: use optimized weight initialization
            checkpoint_layers: indicates which layers to use for checkpointing
            rpe: use relative position encoding in long/short attention blocks
            ensemble: integer indicating the ensemble size, or None for a deterministic model
            finetune: indicates whether to fine-tune the model rather than train it from
                scratch. If set to `True`, the final output layers are removed.
            nglo: number of (additional) global tokens
            dtype: a torch data type. Not used; added only for compatibility with the
                remainder of the codebase.
        """
        super().__init__()

        self.learned_flow = learned_flow
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.time_embedding = time_embedding
        self.ensemble = ensemble
        self.finetune = finetune
        self.nglo = nglo

        if learned_flow:
            self.learned_flow_model = HelioFlowModel(
                img_size=(img_size, img_size),
                use_latitude_in_learned_flow=use_latitude_in_learned_flow,
            )

        match time_embedding["type"]:
            case "linear":
                self.time_dim = time_embedding["time_dim"]
                if learned_flow:
                    self.time_dim += 1
                self.embedding = LinearEmbedding(
                    img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate
                )

                if not self.finetune:
                    self.unembed = LinearDecoder(
                        patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim
                    )
            case "perceiver":
                self.embedding = PerceiverChannelEmbedding(
                    in_chans=in_chans,
                    img_size=img_size,
                    patch_size=patch_size,
                    time_dim=time_embedding["time_dim"],
                    num_queries=time_embedding["n_queries"],
                    embed_dim=embed_dim,
                    drop_rate=drop_rate,
                )
                if not self.finetune:
                    self.unembed = PerceiverDecoder(
                        embed_dim=embed_dim,
                        patch_size=patch_size,
                        out_chans=in_chans,
                    )
            case _:
                raise NotImplementedError(
                    f'Embedding {time_embedding["type"]} has not been implemented.'
                )

        if isinstance(depth, list):
            raise NotImplementedError(
                "Multi scale models are no longer supported. Depth should be a single integer."
            )
        self.backbone = SpectFormer(
            grid_size=img_size // patch_size,
            embed_dim=embed_dim,
            depth=depth,
            n_spectral_blocks=n_spectral_blocks,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            window_size=window_size,
            dp_rank=dp_rank,
            checkpoint_layers=checkpoint_layers,
            rpe=rpe,
            ensemble=ensemble,
            nglo=nglo,
        )

        if init_weights:
            self.apply(self._init_weights)

    def _init_weights(self, module):
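        """
        Optimized weight initialization, applied when the model is constructed with
        `init_weights=True`. Roughly: the patch embedding starts as a pooling-like
        operator, the transformer-block MLPs and attention projections start at
        (near-)identity or zero, and the pixel-shuffle decoder starts out routing
        embedding features back to output pixels.
        """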
        if self.time_embedding["type"] == "linear":
            sampling_step = int(
                np.sqrt(
                    (self.patch_size**2 * self.in_chans * self.time_dim)
                    / self.embed_dim
                )
            )
        else:
            sampling_step = int(
                np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim)
            )
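
        # Patch embedding: zero the projection, then give each output channel a constant
        # weight of 1 / sampling_step over one sampling_step x sampling_step sub-patch of
        # a single input channel/frame (a box-pooling-like start).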
        if isinstance(module, PatchEmbed3D):
            torch.nn.init.zeros_(module.proj.weight)
            c_out = 0
            w_pool = 1.0 / sampling_step
            for k in range(self.in_chans * self.time_dim):
                for i in range(0, self.patch_size, sampling_step):
                    for j in range(0, self.patch_size, sampling_step):
                        module.proj.weight.data[
                            c_out, k, i : i + sampling_step, j : j + sampling_step
                        ] = w_pool
                        c_out += 1
            if module.proj.bias is not None:
                module.proj.bias.data.zero_()
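
        # Spectral gating blocks: start the MLPs near an identity map (identity weights,
        # zero biases).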
        if isinstance(module, BlockSpectralGating):
            for m in [
                module.mlp.fc1,
                module.mlp.fc2,
            ]:
                torch.nn.init.eye_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
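
        # Attention blocks: zero-initialize the MLP and attention projection weights and
        # biases, so these blocks start out close to inert.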
        if isinstance(module, BlockAttention):
            for m in [
                module.mlp.fc1,
                module.mlp.fc2,
            ]:
                torch.nn.init.zeros_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            for m in [
                module.attn.qkv,
                module.attn.proj,
                module.attn.to_dynamic_projection,
            ]:
                torch.nn.init.zeros_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
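
        # Decoder head (a Sequential of a projection followed by PixelShuffle): zero the
        # projection, then set selected weights to 1 so that embedding features are
        # initially routed back to output pixels.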
        if isinstance(module, torch.nn.Sequential):
            if isinstance(module[1], torch.nn.PixelShuffle):
                torch.nn.init.zeros_(module[0].weight)
                if self.time_embedding["type"] == "linear":
                    c_out = 0
                    for k in range(1, self.in_chans + 1):
                        for i in range(
                            self.patch_size**2 // (self.patch_size * sampling_step)
                        ):
                            for j in range(self.patch_size):
                                module[0].weight.data[
                                    c_out : c_out + sampling_step,
                                    j + (k * self.time_dim - 1) * self.patch_size,
                                ] = 1.0
                                c_out += sampling_step
                else:
                    c_out = 0
                    for k in range(2 * self.in_chans):
                        for l_feat in range(self.backbone.embed_dim):
                            module[0].weight.data[c_out, l_feat] = 1.0
                            c_out += 1
                if module[0].bias is not None:
                    module[0].bias.data.zero_()

    def forward(self, batch):
        """
        Args:
            batch: dictionary containing the keys `ts` and `time_delta_input`.
                Their values are tensors with the following shapes:
                    ts: B, C, T, H, W
                    time_delta_input: B, T
        Returns:
            Tensor of shape (B, C, H, W) for deterministic forecasts or (B, E, C, H, W)
            for ensemble forecasts.
        """
        x = batch["ts"]
        dt = batch["time_delta_input"]
        B, C, T, H, W = x.shape
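
        # If a learned flow model is attached and is itself being optimized (its
        # parameters require gradients), return its forecast directly. Otherwise, append
        # its forecast to the input as an additional time step for the transformer.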
        if self.learned_flow:
            y_hat_flow = self.learned_flow_model(batch)
            if any(
                param.requires_grad
                for param in self.learned_flow_model.parameters()
            ):
                return y_hat_flow
            else:
                x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2)
                if self.time_embedding["type"] == "perceiver":
                    dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1)

        tokens = self.embedding(x, dt)
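
        # For ensemble forecasts, replicate each sample E times along the batch
        # dimension; the backbone perturbs the replicas independently via its internal
        # noise injection.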
        if self.ensemble:
            tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0)

        tokens = self.backbone(tokens)

        if self.finetune:
            return tokens

        forecast_hat = self.unembed(tokens)

        assert forecast_hat.shape == (
            B * self.ensemble if self.ensemble else B,
            C,
            H,
            W,
        ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B * self.ensemble if self.ensemble else B, C, H, W)}."

        if self.learned_flow:
            assert y_hat_flow.shape == (
                B,
                C,
                H,
                W,
            ), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}."
            if self.ensemble:
                y_hat_flow = torch.repeat_interleave(
                    y_hat_flow, repeats=self.ensemble, dim=0
                )
            assert y_hat_flow.shape == forecast_hat.shape
            forecast_hat = forecast_hat + y_hat_flow

        assert forecast_hat.shape == (
            B * self.ensemble if self.ensemble else B,
            C,
            H,
            W,
        ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B * self.ensemble if self.ensemble else B, C, H, W)}."

        if self.ensemble:
            forecast_hat = rearrange(
                forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble
            )

        return forecast_hat