|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class HelioFlowModel(nn.Module):
    """Warp the latest input frame(s) with a learned, image-wide flow field.

    A small per-pixel MLP maps normalized pixel coordinates (optionally
    augmented with a latitude-scaled constant channel) to a 2-component
    offset. The offset is added to a base sampling grid and the result is
    used to resample the input with ``F.grid_sample``.
    """

    def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False):
        """
        Args:
            img_size: ``(H, W)`` of the images passed to ``forward``. The
                buffers built here are laid out as ``(1, H, W, ...)``.
            use_latitude_in_learned_flow: If True, a constant channel is
                appended to the MLP input and scaled by the batch's
                ``forecast_latitude`` in ``forward``.
        """
        super().__init__()

        self.use_latitude_in_learned_flow = use_latitude_in_learned_flow

        height, width = img_size

        # ``grid_sample`` reads grid[..., 0] as the x (width) coordinate and
        # grid[..., 1] as the y (height) coordinate. With indexing="xy" the
        # meshgrid outputs have shape (len(v), len(u)), so ``u`` must span the
        # width and ``v`` the height for the result to be (H, W). Building
        # them the other way round only works for square images.
        u = torch.linspace(-1, 1, width)
        v = torch.linspace(-1, 1, height)
        u, v = torch.meshgrid(u, v, indexing="xy")

        # NOTE(review): linspace(-1, 1, n) matches pixel centres only under
        # align_corners=True, while ``forward`` samples with
        # align_corners=False; a zero learned offset is therefore a slight
        # rescale rather than an exact identity. Confirm this is intended.
        self.register_buffer(
            "grid", torch.stack((u, v), dim=2).view(1, height, width, 2)
        )

        # Per-pixel input features of the flow MLP: (u, v), plus a constant
        # channel that carries the latitude when that option is enabled.
        if self.use_latitude_in_learned_flow:
            higher_modes = [u, v, torch.ones_like(u)]
        else:
            higher_modes = [u, v]
        self.register_buffer(
            "higher_modes",
            torch.stack(higher_modes, dim=2).view(1, height, width, -1),
        )

        # Applied independently at every pixel (nn.Linear acts on the last dim).
        self.flow_generator = nn.Sequential(
            nn.Linear(self.higher_modes.shape[3], 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )

    def forward(self, batch):
        """
        Args:
            batch: Dictionary containing keys `ts` and
                `forecast_latitude` (optionally).
                ts (torch.Tensor): B, C, T, H, W
                forecast_latitude (torch.Tensor): B, L
                B - Batch size, C - Channels, T - Input times, H - Image height,
                W - Image width, L - Lead time.

        Returns:
            torch.Tensor: B, C, H, W. The last input frame (or the mean of the
            last two frames when T > 1) resampled along the learned flow.

        Note:
            ``forecast_latitude`` is only read when the model was built with
            ``use_latitude_in_learned_flow=True``. The broadcast against the
            three feature channels below only succeeds for L == 1.
        """
        x = batch["ts"]
        B, C, T, H, W = x.shape
        if T == 1:
            x = x[:, :, -1, :, :]
        else:
            # Average the two most recent frames.
            x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2

        if self.use_latitude_in_learned_flow:
            # NOTE(review): 7 looks like a latitude normalisation constant;
            # its origin is not established in this file.
            latitude = batch["forecast_latitude"] / 7
            # Leave the (u, v) channels untouched and scale only the constant
            # channel, turning it into a per-sample latitude feature.
            channel_scale = torch.cat(
                [
                    torch.ones_like(latitude),
                    torch.ones_like(latitude),
                    latitude,
                ],
                1,
            )[:, None, None, :]
            higher_modes = self.higher_modes * channel_scale
        else:
            higher_modes = self.higher_modes

        flow_field = self.grid + self.flow_generator(higher_modes)
        # Share one field across the batch when it has no batch dependence;
        # a no-op when the latitude branch already produced B fields.
        flow_field = flow_field.expand(B, H, W, 2)

        y_hat = F.grid_sample(
            x,
            flow_field,
            mode="bilinear",
            padding_mode="border",
            align_corners=False,
        )

        return y_hat
|
|