Spaces:
Runtime error
Runtime error
File size: 2,742 Bytes
b73936d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class HelioFlowModel(nn.Module):
def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False):
super().__init__()
self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
u = torch.linspace(-1, 1, img_size[0])
v = torch.linspace(-1, 1, img_size[1])
u, v = torch.meshgrid(u, v, indexing="xy")
self.register_buffer(
"grid", torch.stack((u, v), dim=2).view(1, *img_size, 2)
) # B, H, W, 2
# Higher modes can be used for explicit feature engineering for flow features.
if self.use_latitude_in_learned_flow:
higher_modes = [u, v, torch.ones_like(u)]
else:
higher_modes = [
u,
v,
]
self.register_buffer(
"higher_modes", torch.stack(higher_modes, dim=2).view(1, *img_size, -1)
)
self.flow_generator = nn.Sequential(
nn.Linear(self.higher_modes.shape[3], 128),
nn.GELU(),
nn.Linear(128, 2),
)
def forward(self, batch):
"""
Args:
batch: Dictionary containing keys `ts` and
`forecast_latitude` (optionally).
ts (torch.Tensor): B, C, T, H, W
forecast_latitude (torch.Tensor): B, L
B - Batch size, C - Channels, T - Input times, H - Image height,
W - Image width, L - Lead time.
"""
x = batch["ts"]
B, C, T, H, W = x.shape
if T == 1:
x = x[:, :, -1, :, :]
else:
# Taking the average of the last two time stamps
x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2
# Flow fields have the shape B, H_out, W_out, 2
if self.use_latitude_in_learned_flow:
broadcast_lat = batch["forecast_latitude"] / 7
broadcast_lat = torch.concatenate(
[
torch.ones_like(broadcast_lat),
torch.ones_like(broadcast_lat),
broadcast_lat,
],
1,
)[:, None, None, :]
higher_modes = self.higher_modes * broadcast_lat
flow_field = self.grid + self.flow_generator(higher_modes)
else:
flow_field = self.grid + self.flow_generator(self.higher_modes)
flow_field = flow_field.expand(B, H, W, 2)
y_hat = F.grid_sample(
x,
flow_field,
mode="bilinear",
padding_mode="border", # Possible values: zeros, border, or reflection.
align_corners=False,
)
return y_hat
|