import torch
from einops import rearrange
from torch import nn
import numpy as np

from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
from .embedding import (
    LinearEmbedding,
    PatchEmbed3D,
    PerceiverChannelEmbedding,
    LinearDecoder,
    PerceiverDecoder,
)
from .flow import HelioFlowModel


class HelioSpectFormer(nn.Module):
    """
    A note on the ensemble capability: Ensembles of size E are generated by
    setting `ensemble=E`. In this case, the forward pass generates ensemble
    members after tokenization by increasing the batch dimension from B to
    B x E. Noise is injected in the `self.backbone` SpectFormer blocks. After
    the backbone, ensemble members ride along implicitly in the batch
    dimension. (This mainly concerns the `self.unembed` pass.) An explicit
    ensemble dimension is only created at the end.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_dim: int,
        time_embedding: dict,
        depth: int,
        n_spectral_blocks: int,
        num_heads: int,
        mlp_ratio: float,
        drop_rate: float,
        window_size: int,
        dp_rank: int,
        learned_flow: bool = False,
        use_latitude_in_learned_flow: bool = False,
        init_weights: bool = False,
        checkpoint_layers: list[int] | None = None,
        rpe: bool = False,
        ensemble: int | None = None,
        finetune: bool = True,
        nglo: int = 0,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            img_size: input image size
            patch_size: patch size
            in_chans: number of input channels
            embed_dim: embedding dimension
            time_embedding: dictionary to configure the temporal embedding:
                `type` (str, required): indicates the embedding type, `linear` or `perceiver`.
                `time_dim` (int): length of the time dimension. Required for the linear embedding.
                `n_queries` (int): number of perceiver queries. Required for the perceiver embedding.
            depth: number of transformer blocks
            n_spectral_blocks: number of spectral gating blocks
            num_heads: number of transformer heads
            mlp_ratio: MLP ratio for transformer blocks
            drop_rate: dropout rate
            window_size: window size for long/short attention
            dp_rank: dp rank for long/short attention
            learned_flow: if true, combine a learned flow model with the SpectFormer
            use_latitude_in_learned_flow: use latitudes in the learned flow
            init_weights: use optimized weight initialization
            checkpoint_layers: indicates which layers to use for checkpointing
            rpe: use relative position encoding in long/short attention blocks
            ensemble: integer indicating the ensemble size, or None for a deterministic model
            finetune: indicates whether to train from scratch or fine-tune the model.
                If set to `True`, the final output layers are removed.
            nglo: number of (additional) global tokens
            dtype: a torch data type. Not used; added only for compatibility
                with the remainder of the codebase.
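
        Example `time_embedding` configurations (illustrative values only;
        `n_queries=16` is an assumption, not a recommended setting):
            {"type": "linear", "time_dim": 2}
            {"type": "perceiver", "time_dim": 2, "n_queries": 16}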
""" super().__init__() self.learned_flow = learned_flow self.patch_size = patch_size self.embed_dim = embed_dim self.in_chans = in_chans self.time_embedding = time_embedding self.ensemble = ensemble self.finetune = finetune self.nglo = nglo if learned_flow: self.learned_flow_model = HelioFlowModel( img_size=(img_size, img_size), use_latitude_in_learned_flow=use_latitude_in_learned_flow, ) match time_embedding["type"]: case "linear": self.time_dim = time_embedding["time_dim"] if learned_flow: self.time_dim += 1 self.embedding = LinearEmbedding( img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate ) if not self.finetune: self.unembed = LinearDecoder( patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim ) case "perceiver": self.embedding = PerceiverChannelEmbedding( in_chans=in_chans, img_size=img_size, patch_size=patch_size, time_dim=time_embedding["time_dim"], num_queries=time_embedding["n_queries"], embed_dim=embed_dim, drop_rate=drop_rate, ) if not self.finetune: self.unembed = PerceiverDecoder( embed_dim=embed_dim, patch_size=patch_size, out_chans=in_chans, ) case _: raise NotImplementedError( f'Embedding {time_embedding["type"]} has not been implemented.' ) if isinstance(depth, list): raise NotImplementedError( "Multi scale models are no longer supported. Depth should be a single integer." ) self.backbone = SpectFormer( grid_size=img_size // patch_size, embed_dim=embed_dim, depth=depth, n_spectral_blocks=n_spectral_blocks, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, window_size=window_size, dp_rank=dp_rank, checkpoint_layers=checkpoint_layers, rpe=rpe, ensemble=ensemble, nglo=nglo, ) if init_weights: self.apply(self._init_weights) # @staticmethod # def _checkpoint_wrapper( # model: nn.Module, data: tuple[Tensor, Tensor | None] # ) -> Tensor: # return checkpoint(model, data, use_reentrant=False) def _init_weights(self, module): if self.time_embedding["type"] == "linear": # sampling_step * embed_dim = patch_size**2 * in_chans * time_dim sampling_step = int( np.sqrt( (self.patch_size**2 * self.in_chans * self.time_dim) / self.embed_dim ) ) else: sampling_step = int( np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim) ) if isinstance(module, PatchEmbed3D): torch.nn.init.zeros_(module.proj.weight) c_out = 0 w_pool = 1.0 / sampling_step for k in range(self.in_chans * self.time_dim): for i in range(0, self.patch_size, sampling_step): for j in range(0, self.patch_size, sampling_step): module.proj.weight.data[ c_out, k, i : i + sampling_step, j : j + sampling_step ] = w_pool c_out += 1 if module.proj.bias is not None: module.proj.bias.data.zero_() if isinstance(module, BlockSpectralGating): for m in [ module.mlp.fc1, module.mlp.fc2, ]: # m.weight.data.normal_(mean=0.0, std=0.01) # torch.nn.init.eye_(m.weight) torch.nn.init.eye_(m.weight) if m.bias is not None: m.bias.data.zero_() if isinstance(module, BlockAttention): for m in [ module.mlp.fc1, module.mlp.fc2, ]: # torch.nn.init.eye_(m.weight) torch.nn.init.zeros_(m.weight) if m.bias is not None: m.bias.data.zero_() for m in [ module.attn.qkv, module.attn.proj, module.attn.to_dynamic_projection, ]: # m.weight.data.normal_(mean=0.0, std=0.01) # torch.nn.init.eye_(m.weight) torch.nn.init.zeros_(m.weight) if m.bias is not None: m.bias.data.zero_() if isinstance(module, torch.nn.Sequential): if isinstance(module[1], torch.nn.PixelShuffle): # torch.nn.init.eye_(module[0].weight.data[:,:,0,0]) torch.nn.init.zeros_(module[0].weight) if self.time_embedding["type"] == "linear": c_out = 0 for k in 
                    for k in range(1, self.in_chans + 1):
                        for i in range(
                            self.patch_size**2 // (self.patch_size * sampling_step)
                        ):
                            for j in range(self.patch_size):
                                module[0].weight.data[
                                    c_out : c_out + sampling_step,
                                    j + (k * self.time_dim - 1) * self.patch_size,
                                ] = 1.0
                                c_out += sampling_step
                else:
                    c_out = 0
                    for k in range(2 * self.in_chans):
                        # l = 0
                        for l_feat in range(self.backbone.embed_dim):
                            module[0].weight.data[c_out, l_feat] = 1.0
                            c_out += 1
                if module[0].bias is not None:
                    module[0].bias.data.zero_()

    def forward(self, batch):
        """
        Args:
            batch: Dictionary containing the keys `ts` and `time_delta_input`.
                Their values are tensors with shapes as follows:
                    ts: B, C, T, H, W
                    time_delta_input: B, T
        Returns:
            Tensor of shape (B, C, H, W) for deterministic forecasts or
            (B, E, C, H, W) for ensemble forecasts.
        """
        x = batch["ts"]
        dt = batch["time_delta_input"]
        B, C, T, H, W = x.shape

        if self.learned_flow:
            y_hat_flow = self.learned_flow_model(batch)  # B, C, H, W
            if any(
                [param.requires_grad for param in self.learned_flow_model.parameters()]
            ):
                return y_hat_flow
            else:
                x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2)  # B, C, T+1, H, W
            if self.time_embedding["type"] == "perceiver":
                dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1)

        # embed the data
        tokens = self.embedding(x, dt)

        # copy tokens in case of ensemble forecast
        if self.ensemble:
            # B L D -> (B E) L D == BE L D
            tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0)

        # pass the time series through the encoder
        tokens = self.backbone(tokens)

        if self.finetune:
            return tokens

        # Unembed the tokens
        # BE L D -> BE C H W
        forecast_hat = self.unembed(tokens)

        assert forecast_hat.shape == (
            B * self.ensemble if self.ensemble else B,
            C,
            H,
            W,
        ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."

        if self.learned_flow:
            assert y_hat_flow.shape == (
                B,
                C,
                H,
                W,
            ), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}."
            if self.ensemble:
                y_hat_flow = torch.repeat_interleave(
                    y_hat_flow, repeats=self.ensemble, dim=0
                )
                assert y_hat_flow.shape == forecast_hat.shape
            forecast_hat = forecast_hat + y_hat_flow

            assert forecast_hat.shape == (
                B * self.ensemble if self.ensemble else B,
                C,
                H,
                W,
            ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."

        if self.ensemble:
            forecast_hat = rearrange(
                forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble
            )

        return forecast_hat
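

# Minimal usage sketch (assumptions, not a validated configuration). The
# hyperparameters below are illustrative and must satisfy the backbone's own
# constraints (e.g. `window_size` and `dp_rank` relative to the 8 x 8 token
# grid), which live in `.spectformer` and are not checked here. Because this
# module uses relative imports, run it as `python -m <package>.<module>`.
if __name__ == "__main__":
    B, C, T, H = 1, 3, 2, 128  # batch, channels, input time steps, image size
    model = HelioSpectFormer(
        img_size=H,
        patch_size=16,
        in_chans=C,
        embed_dim=128,
        time_embedding={"type": "linear", "time_dim": T},
        depth=4,
        n_spectral_blocks=2,
        num_heads=4,
        mlp_ratio=4.0,
        drop_rate=0.0,
        window_size=4,
        dp_rank=4,
        ensemble=2,  # E = 2 ensemble members
        finetune=False,  # keep the output layers so a full forecast is returned
    )
    batch = {
        "ts": torch.randn(B, C, T, H, H),
        "time_delta_input": torch.zeros(B, T),  # placeholder time deltas
    }
    with torch.no_grad():
        out = model(batch)
    print(out.shape)  # expected: (B, E, C, H, W) == (1, 2, 3, 128, 128)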