import torch
from einops import rearrange
from torch import nn
import numpy as np

from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
from .embedding import (
    LinearEmbedding,
    PatchEmbed3D,
    PerceiverChannelEmbedding,
    LinearDecoder,
    PerceiverDecoder,
)
from .flow import HelioFlowModel


class HelioSpectFormer(nn.Module):
    """
    A note on the ensemble capability:

    Ensembles of size E are generated by setting `ensemble=E`. In this case, the forward
    pass generates ensemble members after tokenization by expanding the batch dimension
    B to B x E. Noise is injected in the `self.backbone` SpectFormer blocks. After the
    backbone, ensemble members ride along implicitly in the batch dimension, mainly
    through the `self.unembed` pass. An explicit ensemble dimension is only created at
    the end of the forward pass.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_dim: int,
        time_embedding: dict,
        depth: int,
        n_spectral_blocks: int,
        num_heads: int,
        mlp_ratio: float,
        drop_rate: float,
        window_size: int,
        dp_rank: int,
        learned_flow: bool = False,
        use_latitude_in_learned_flow: bool = False,
        init_weights: bool = False,
        checkpoint_layers: list[int] | None = None,
        rpe: bool = False,
        ensemble: int | None = None,
        finetune: bool = True,
        nglo: int = 0,
        dtype: torch.dtype | None = None,
    ) -> None:
""" | |
Args: | |
img_size: input image size | |
patch_size: patch size | |
in_chans: number of iput channels | |
embed_dim: embeddin dimension | |
time_embedding: dictionary to configure temporal embedding: | |
`type` (str, required): indicates embedding type. `linear`, `perceiver`. | |
`time_dim` (int): indicates length of time dimension. required for linear embedding. | |
`n_queries` (int): indicates number of perceiver queries. required for perceiver. | |
depth: number of transformer blocks | |
n_spectral_blocks: number of spectral gating blocks | |
num_heads: Number of transformer heads | |
mlp_ratio: MLP ratio for transformer blocks | |
drop_rate: dropout rate | |
window_size: window size for long/short attention | |
dp_rank: dp rank for long/short attention | |
learned_flow: if true, combine learned flow model with spectformer | |
use_latitude_in_learned_flow: use latitudes in learned flow | |
init_weights: use optimized weight initialization | |
checkpoint_layers: indicate which layers to use for checkpointing | |
rpe: Use relative position encoding in Long-Short attention blocks. | |
ensemble: Integer indicating ensemble size or None for deterministic model. | |
finetune: Indicates whether to train from scrach or fine-tune the model. If set to `True`, the final output layers are removed. | |
nglo: Number of (additional) global tokens. | |
dtype: A torch data type. Not used and added only for compatibility with the remainder of the codebase. | |
""" | |
        super().__init__()

        self.learned_flow = learned_flow
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.time_embedding = time_embedding
        self.ensemble = ensemble
        self.finetune = finetune
        self.nglo = nglo

        if learned_flow:
            self.learned_flow_model = HelioFlowModel(
                img_size=(img_size, img_size),
                use_latitude_in_learned_flow=use_latitude_in_learned_flow,
            )

        match time_embedding["type"]:
            case "linear":
                self.time_dim = time_embedding["time_dim"]
                if learned_flow:
                    self.time_dim += 1
                self.embedding = LinearEmbedding(
                    img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate
                )
                if not self.finetune:
                    self.unembed = LinearDecoder(
                        patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim
                    )
            case "perceiver":
                self.embedding = PerceiverChannelEmbedding(
                    in_chans=in_chans,
                    img_size=img_size,
                    patch_size=patch_size,
                    time_dim=time_embedding["time_dim"],
                    num_queries=time_embedding["n_queries"],
                    embed_dim=embed_dim,
                    drop_rate=drop_rate,
                )
                if not self.finetune:
                    self.unembed = PerceiverDecoder(
                        embed_dim=embed_dim,
                        patch_size=patch_size,
                        out_chans=in_chans,
                    )
            case _:
                raise NotImplementedError(
                    f'Embedding {time_embedding["type"]} has not been implemented.'
                )

        if isinstance(depth, list):
            raise NotImplementedError(
                "Multi scale models are no longer supported. Depth should be a single integer."
            )

        self.backbone = SpectFormer(
            grid_size=img_size // patch_size,
            embed_dim=embed_dim,
            depth=depth,
            n_spectral_blocks=n_spectral_blocks,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            window_size=window_size,
            dp_rank=dp_rank,
            checkpoint_layers=checkpoint_layers,
            rpe=rpe,
            ensemble=ensemble,
            nglo=nglo,
        )

        if init_weights:
            self.apply(self._init_weights)

    # @staticmethod
    # def _checkpoint_wrapper(
    #     model: nn.Module, data: tuple[Tensor, Tensor | None]
    # ) -> Tensor:
    #     return checkpoint(model, data, use_reentrant=False)

    def _init_weights(self, module):
        if self.time_embedding["type"] == "linear":
            # sampling_step**2 * embed_dim == patch_size**2 * in_chans * time_dim
            sampling_step = int(
                np.sqrt(
                    (self.patch_size**2 * self.in_chans * self.time_dim)
                    / self.embed_dim
                )
            )
        else:
            sampling_step = int(
                np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim)
            )
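        # Worked example (hypothetical values): with patch_size=16, in_chans=8,
        # time_dim=2 and embed_dim=1024, the linear branch above gives
        # sqrt((16**2 * 8 * 2) / 1024) = sqrt(4) = 2, i.e. sampling_step = 2.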
        if isinstance(module, PatchEmbed3D):
            torch.nn.init.zeros_(module.proj.weight)
            c_out = 0
            w_pool = 1.0 / sampling_step
            for k in range(self.in_chans * self.time_dim):
                for i in range(0, self.patch_size, sampling_step):
                    for j in range(0, self.patch_size, sampling_step):
                        module.proj.weight.data[
                            c_out, k, i : i + sampling_step, j : j + sampling_step
                        ] = w_pool
                        c_out += 1
            if module.proj.bias is not None:
                module.proj.bias.data.zero_()
        if isinstance(module, BlockSpectralGating):
            for m in [
                module.mlp.fc1,
                module.mlp.fc2,
            ]:
                # m.weight.data.normal_(mean=0.0, std=0.01)
                # torch.nn.init.eye_(m.weight)
                torch.nn.init.eye_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
        if isinstance(module, BlockAttention):
            for m in [
                module.mlp.fc1,
                module.mlp.fc2,
            ]:
                # torch.nn.init.eye_(m.weight)
                torch.nn.init.zeros_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            for m in [
                module.attn.qkv,
                module.attn.proj,
                module.attn.to_dynamic_projection,
            ]:
                # m.weight.data.normal_(mean=0.0, std=0.01)
                # torch.nn.init.eye_(m.weight)
                torch.nn.init.zeros_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
        if isinstance(module, torch.nn.Sequential):
            if isinstance(module[1], torch.nn.PixelShuffle):
                # torch.nn.init.eye_(module[0].weight.data[:,:,0,0])
                torch.nn.init.zeros_(module[0].weight)
                if self.time_embedding["type"] == "linear":
                    c_out = 0
                    for k in range(1, self.in_chans + 1):
                        for i in range(
                            self.patch_size**2 // (self.patch_size * sampling_step)
                        ):
                            for j in range(self.patch_size):
                                module[0].weight.data[
                                    c_out : c_out + sampling_step,
                                    j + (k * self.time_dim - 1) * self.patch_size,
                                ] = 1.0
                                c_out += sampling_step
                else:
                    c_out = 0
                    for k in range(2 * self.in_chans):
                        # l = 0
                        for l_feat in range(self.backbone.embed_dim):
                            module[0].weight.data[c_out, l_feat] = 1.0
                        c_out += 1
                if module[0].bias is not None:
                    module[0].bias.data.zero_()

    def forward(self, batch):
        """
        Args:
            batch: Dictionary containing the keys `ts` and `time_delta_input`.
                Their values are tensors with the following shapes.
                ts: B, C, T, H, W
                time_delta_input: B, T
        Returns:
            Tensor of shape (B, C, H, W) for deterministic forecasts or (B, E, C, H, W) for ensemble forecasts.
        """
x = batch["ts"] | |
dt = batch["time_delta_input"] | |
B, C, T, H, W = x.shape | |
if self.learned_flow: | |
y_hat_flow = self.learned_flow_model(batch) # B, C, H, W | |
if any( | |
[param.requires_grad for param in self.learned_flow_model.parameters()] | |
): | |
return y_hat_flow | |
else: | |
x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2) # B, C, T+1, H, W | |
if self.time_embedding["type"] == "perceiver": | |
dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1) | |
# embed the data | |
tokens = self.embedding(x, dt) | |
# copy tokens in case of ensemble forecast | |
if self.ensemble: | |
# B L D -> (B E) L D == BE L D | |
tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0) | |
# pass the time series through the encoder | |
tokens = self.backbone(tokens) | |
if self.finetune: | |
return tokens | |
# Unembed the tokens | |
# BE L D -> BE C H W | |
forecast_hat = self.unembed(tokens) | |
assert forecast_hat.shape == ( | |
B * self.ensemble if self.ensemble else B, | |
C, | |
H, | |
W, | |
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}." | |
if self.learned_flow: | |
assert y_hat_flow.shape == ( | |
B, | |
C, | |
H, | |
W, | |
), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}." | |
if self.ensemble: | |
y_hat_flow = torch.repeat_interleave( | |
y_hat_flow, repeats=self.ensemble, dim=0 | |
) | |
assert y_hat_flow.shape == forecast_hat.shape | |
forecast_hat = forecast_hat + y_hat_flow | |
assert forecast_hat.shape == ( | |
B * self.ensemble if self.ensemble else B, | |
C, | |
H, | |
W, | |
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}." | |
if self.ensemble: | |
forecast_hat = rearrange( | |
forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble | |
) | |
return forecast_hat | |
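

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only). The hyperparameters below are
    # hypothetical and kept small; whether they satisfy the internals of the imported
    # embedding and backbone modules depends on their implementations, and the
    # relative imports above mean this only runs as part of the package
    # (e.g. `python -m <package>.<this_module>`).
    B, C, T, H = 1, 2, 2, 64
    model = HelioSpectFormer(
        img_size=H,
        patch_size=16,
        in_chans=C,
        embed_dim=128,
        time_embedding={"type": "linear", "time_dim": T},
        depth=2,
        n_spectral_blocks=1,
        num_heads=4,
        mlp_ratio=4.0,
        drop_rate=0.0,
        window_size=2,
        dp_rank=2,
        ensemble=None,
        finetune=False,  # keep the decoder so the forward pass returns (B, C, H, W)
    )
    batch = {
        "ts": torch.randn(B, C, T, H, H),
        "time_delta_input": torch.randn(B, T),
    }
    with torch.no_grad():
        y_hat = model(batch)
    print(y_hat.shape)  # expected: (B, C, H, H) for the deterministic model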