Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class HelioFlowModel(nn.Module):
    """Warp the latest input frame with a learned, input-independent flow field.

    A small per-pixel MLP (``flow_generator``) maps fixed coordinate features
    (``higher_modes``) to a 2-D offset. The offset is added to a normalized base
    sampling grid (``grid``), and the result is used by
    :func:`torch.nn.functional.grid_sample` to resample the input.

    Args:
        img_size: ``(H, W)`` of the images the model is applied to. Both grid
            buffers are built at this size, so ``forward`` requires the input
            height/width to match it.
        use_latitude_in_learned_flow: If True, a third feature channel
            (constant 1, replaced in ``forward`` by the scaled
            ``forecast_latitude``) is fed to the flow MLP alongside the
            ``(x, y)`` coordinates.
    """

    def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False):
        super().__init__()
        self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
        height, width = img_size
        # grid_sample reads grid[..., 0] as the x (width) coordinate and
        # grid[..., 1] as the y (height) coordinate. With indexing="xy",
        # meshgrid(u, v) returns tensors of shape (len(v), len(u)), so u needs
        # W points and v needs H points to yield an (H, W) layout that matches
        # the view below. Sizing them the other way round produces a (W, H)
        # tensor that `.view` silently scrambles for non-square images.
        u = torch.linspace(-1, 1, width)
        v = torch.linspace(-1, 1, height)
        u, v = torch.meshgrid(u, v, indexing="xy")  # each (H, W)
        self.register_buffer(
            "grid", torch.stack((u, v), dim=2).view(1, height, width, 2)
        )  # 1, H, W, 2

        # Per-pixel input features for the flow MLP. Higher modes can be added
        # here for explicit feature engineering of the flow.
        if self.use_latitude_in_learned_flow:
            # The constant channel is multiplied by the scaled latitude in
            # forward(), so it effectively becomes a latitude feature.
            higher_modes = [u, v, torch.ones_like(u)]
        else:
            higher_modes = [u, v]
        self.register_buffer(
            "higher_modes",
            torch.stack(higher_modes, dim=2).view(1, height, width, -1),
        )  # 1, H, W, F

        # Per-pixel MLP: F coordinate features -> 2-D offset added to the grid.
        self.flow_generator = nn.Sequential(
            nn.Linear(self.higher_modes.shape[3], 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )

    def forward(self, batch):
        """Resample the most recent input frame(s) along the learned flow field.

        Args:
            batch: Dictionary containing keys ``ts`` and (read only when
                ``use_latitude_in_learned_flow`` is True) ``forecast_latitude``.
                ts (torch.Tensor): B, C, T, H, W
                forecast_latitude (torch.Tensor): B, L
                B - Batch size, C - Channels, T - Input times, H - Image height,
                W - Image width, L - Lead time.

        Returns:
            torch.Tensor: ``B, C, H, W``. This is the last input frame (or the
            mean of the last two frames when ``T > 1``) resampled with
            ``grid_sample`` at ``grid + flow_generator(higher_modes)``.

        Note:
            ``forecast_latitude`` is expanded to ``3 * L`` features and then
            multiplied against the 3 mode channels. As written, this only
            broadcasts when ``L == 1``. H and W must equal the ``img_size``
            given at construction.
        """
        x = batch["ts"]
        B, C, T, H, W = x.shape
        if T == 1:
            x = x[:, :, -1, :, :]
        else:
            # Average of the last two time stamps.
            x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2

        # Flow fields have the shape B, H_out, W_out, 2.
        # NOTE: the MLP runs at every pixel, so its hidden layer holds
        # H * W * 128 activations (2**31 values at the default 4096x4096).
        if self.use_latitude_in_learned_flow:
            # Fixed 1/7 scaling of the latitude; the constant's origin is not
            # documented here.
            broadcast_lat = batch["forecast_latitude"] / 7
            # Multiplier [1, 1, lat]: keeps the (x, y) channels and replaces the
            # constant-1 channel with the scaled latitude.
            broadcast_lat = torch.cat(
                [
                    torch.ones_like(broadcast_lat),
                    torch.ones_like(broadcast_lat),
                    broadcast_lat,
                ],
                1,
            )[:, None, None, :]  # B, 1, 1, 3L
            higher_modes = self.higher_modes * broadcast_lat  # B, H, W, 3
            flow_field = self.grid + self.flow_generator(higher_modes)
        else:
            flow_field = self.grid + self.flow_generator(self.higher_modes)
        # Broadcast to the batch size (a no-op when already B, H, W, 2).
        flow_field = flow_field.expand(B, H, W, 2)

        # NOTE(review): with align_corners=False, the grid values -1/+1 denote
        # the outer edges of the corner pixels rather than their centers. A
        # linspace(-1, 1, N) base grid is therefore not an exact identity warp;
        # the learned offset must absorb the difference.
        y_hat = F.grid_sample(
            x,
            flow_field,
            mode="bilinear",
            padding_mode="border",  # Possible values: zeros, border, or reflection.
            align_corners=False,
        )
        return y_hat