johannesschmude committed on
Commit b73936d · 1 Parent(s): 3808ef8

Initial commit

app.py ADDED
@@ -0,0 +1,325 @@
1
+ import socket
2
+ import yaml
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import DataLoader
12
+
13
+ import matplotlib.pyplot as plt
14
+ import sunpy.visualization.colormaps as sunpy_cm
15
+
16
+ import gradio as gr
17
+ from huggingface_hub import snapshot_download
18
+
19
+ from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
20
+ from surya.models.helio_spectformer import HelioSpectFormer
21
+ from surya.utils.data import build_scalers, custom_collate_fn
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ SDO_CHANNELS = [
26
+ "aia94",
27
+ "aia131",
28
+ "aia171",
29
+ "aia193",
30
+ "aia211",
31
+ "aia304",
32
+ "aia335",
33
+ "aia1600",
34
+ "hmi_m",
35
+ "hmi_bx",
36
+ "hmi_by",
37
+ "hmi_bz",
38
+ "hmi_v",
39
+ ]
40
+
41
+ @dataclass
42
+ class SDOImage:
43
+ channel: str
44
+ data: np.ndarray
45
+ timestamp: str
46
+ type: str
47
+
48
+ def download_data():
49
+ snapshot_download(
50
+ repo_id="nasa-ibm-ai4science/Surya-1.0",
51
+ local_dir="data/Surya-1.0",
52
+ allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
53
+ token=None,
54
+ )
55
+ snapshot_download(
56
+ repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
57
+ repo_type="dataset",
58
+ local_dir="data/Surya-1.0_validation_data",
59
+ allow_patterns="20140107_1[5-9]??.nc",
60
+ token=None,
61
+ )
62
+
63
+ def get_dataset(config, scalers) -> HelioNetCDFDataset:
64
+ dataset = HelioNetCDFDataset(
65
+ index_path="tests/test_surya_index.csv",
66
+ time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
67
+ time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
68
+ n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
69
+ rollout_steps=0,
70
+ channels=config["data"]["sdo_channels"],
71
+ drop_hmi_probability=config["data"]["drop_hmi_probability"],
72
+ num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
73
+ use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
74
+ scalers=scalers,
75
+ phase="valid",
76
+ pooling=config["data"]["pooling"],
77
+ random_vert_flip=config["data"]["random_vert_flip"],
78
+ )
79
+ logger.info(f"Initialized the dataset. {len(dataset)} samples.")
80
+
81
+ return dataset
82
+
83
+ def get_scalers() -> dict:
84
+ scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
85
+ scalers = build_scalers(info=scalers_info)
86
+ logger.info("Built the scalers.")
87
+ return scalers
88
+
89
+ def get_model_from_config(config) -> HelioSpectFormer:
90
+ model = HelioSpectFormer(
91
+ img_size=config["model"]["img_size"],
92
+ patch_size=config["model"]["patch_size"],
93
+ in_chans=len(config["data"]["sdo_channels"]),
94
+ embed_dim=config["model"]["embed_dim"],
95
+ time_embedding={
96
+ "type": "linear",
97
+ "time_dim": len(config["data"]["time_delta_input_minutes"]),
98
+ },
99
+ depth=config["model"]["depth"],
100
+ n_spectral_blocks=config["model"]["n_spectral_blocks"],
101
+ num_heads=config["model"]["num_heads"],
102
+ mlp_ratio=config["model"]["mlp_ratio"],
103
+ drop_rate=config["model"]["drop_rate"],
104
+ dtype=torch.bfloat16,
105
+ window_size=config["model"]["window_size"],
106
+ dp_rank=config["model"]["dp_rank"],
107
+ learned_flow=config["model"]["learned_flow"],
108
+ use_latitude_in_learned_flow=config["model"]["learned_flow"],
109
+ init_weights=False,
110
+ checkpoint_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
111
+ rpe=config["model"]["rpe"],
112
+ ensemble=config["model"]["ensemble"],
113
+ finetune=config["model"]["finetune"],
114
+ )
115
+ logger.info("Initialized the model.")
116
+
117
+ return model
118
+
119
+ def get_config() -> dict:
120
+ with open("data/Surya-1.0/config.yaml") as fp:
121
+ config = yaml.safe_load(fp)
122
+
123
+ return config
124
+
125
+ def setup():
126
+ logger.info("Loading data ...")
127
+ download_data()
128
+ config = get_config()
129
+ scalers = get_scalers()
130
+
131
+ logger.info("Initializing dataset ...")
132
+ dataset = get_dataset(config, scalers)
133
+
134
+ logger.info("Initializing model ...")
135
+ model = get_model_from_config(config)
136
+ if torch.cuda.is_available():
137
+ device = torch.cuda.current_device()
138
+ logger.info(f"GPU detected. Running the test on device {device}.")
139
+ else:
140
+ device = "cpu"
141
+ logger.warning(f"No GPU detected. Running the test on CPU.")
142
+ model.to(device)
143
+ n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
144
+ logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")
145
+ path_weights = "data/Surya-1.0/surya.366m.v1.pt"
146
+ weights = torch.load(
147
+ path_weights, map_location=torch.device(device), weights_only=True
148
+ )
149
+ model.load_state_dict(weights, strict=True)
150
+ logger.info("Loaded weights.")
151
+
152
+ return dataset, model, device
153
+
154
+ def batch_step(
155
+ model: HelioSpectFormer,
156
+ sample_data: dict,
157
+ sample_metadata: dict,
158
+ device: int | str,
159
+ hours_ahead: int = 1,
160
+ ) -> np.ndarray:
161
+ """
162
+ Perform a single batch step for the given model, batch data, metadata, and device.
163
+
164
+ Args:
165
+ model: The PyTorch model to use for prediction.
166
+ sample_data: A dictionary containing input and target data for the batch.
167
+ sample_metadata: A dictionary containing metadata for the batch, including timestamps.
168
+ device: The device to use for computation ('cpu', 'cuda' or device number).
169
+ hours_ahead: The number of steps to forecast ahead. Defaults to 1.
170
+
171
+ Returns:
172
+ np.ndarray: Output data.
173
+ """
174
+
175
+ data_returned = []
176
+ forecast_hat = None # Initialize forecast_hat
177
+
178
+ for step in range(1, hours_ahead + 1):
179
+ if step == 1:
180
+ curr_batch = {
181
+ key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
182
+ for key in ["ts", "time_delta_input"]
183
+ }
184
+ else:
185
+ # Use the previous forecast_hat from the previous iteration
186
+ if forecast_hat is not None:
187
+ curr_batch["ts"] = torch.cat(
188
+ (curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
189
+ dim=2,
190
+ )
191
+
192
+ forecast_hat = model(curr_batch)
193
+
194
+ data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
195
+
196
+ return data_returned
197
+
198
+
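+ # Illustrative sketch (dummy shapes, not the Surya configuration) of the rollout
+ # bookkeeping in `batch_step` above: after the first step, the oldest input frame
+ # is dropped and the latest forecast is appended, so the time dimension stays fixed.
+ def _demo_rollout_window() -> None:
+     ts = torch.zeros(1, 13, 2, 8, 8)           # (B, C, T, H, W)
+     forecast_hat = torch.ones(1, 13, 8, 8)     # (B, C, H, W) as returned by the model
+     ts = torch.cat((ts[:, :, 1:, ...], forecast_hat[:, :, None, ...]), dim=2)
+     assert ts.shape == (1, 13, 2, 8, 8)        # window length T is unchanged
+
+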
199
+ def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
200
+ plt_channel_str = SDO_CHANNELS[plt_channel_idx]
201
+
202
+ input_timestamp_1 = dataset.valid_indices[init_time_idx]
203
+ input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h")
204
+ output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h")
205
+
206
+ input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M")
207
+ input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M")
208
+ output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M")
209
+
210
+ sample_data, sample_metadata = dataset[init_time_idx]
211
+ with torch.no_grad():
212
+ model_output = batch_step(
213
+ model,
214
+ sample_data,
215
+ sample_metadata,
216
+ device,
217
+ hours_ahead
218
+ )
219
+
220
+ means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()
221
+
222
+ vmin = float("-inf")
223
+ vmax = float("inf")
224
+ input_image = []
225
+ for i in range(2):
226
+ input_image.append(
227
+ inverse_transform_single_channel(
228
+ sample_data["ts"][plt_channel_idx, i],
229
+ mean=means[plt_channel_idx],
230
+ std=stds[plt_channel_idx],
231
+ epsilon=epsilons[plt_channel_idx],
232
+ sl_scale_factor=sl_scale_factors[plt_channel_idx],
233
+ )
234
+ )
235
+ vmin = max(vmin, sample_data["ts"][plt_channel_idx, i].min())
236
+ #vmax = min(vmax, np.quantile(sample_data["ts"][plt_channel_idx, i], 0.99))
237
+ vmax = min(vmax, sample_data["ts"][plt_channel_idx, i].max())
238
+
239
+ if plt_channel_str.startswith("aia"):
240
+ cm_name = "sdo" + plt_channel_str
241
+ else:
242
+ cm_name = "hmimag"
243
+
244
+ input_image = [
245
+ sunpy_cm.cmlist[cm_name](
246
+ (img[::-1]-vmin) / (vmax-vmin), bytes=True
247
+ )
248
+ for img in input_image
249
+ ]
250
+
251
+ output_image = inverse_transform_single_channel(
252
+ model_output[plt_channel_idx],
253
+ mean=means[plt_channel_idx],
254
+ std=stds[plt_channel_idx],
255
+ epsilon=epsilons[plt_channel_idx],
256
+ sl_scale_factor=sl_scale_factors[plt_channel_idx],
257
+ )
258
+ output_image = sunpy_cm.cmlist[cm_name](
259
+ (output_image[::-1]-vmin) / (vmax-vmin), bytes=True
260
+ )
261
+
262
+ return input_timestamp_0, input_image[0], input_timestamp_1, input_image[1], output_timestamp, output_image
263
+
264
+ logging.basicConfig(level=logging.INFO)
265
+ dataset, model, device = setup()
266
+ hostname = socket.getfqdn()
267
+ logging.info(f"Launching app on {hostname}")
268
+
269
+ with gr.Blocks() as demo:
270
+ gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")
271
+ #with gr.Row():
272
+ #with gr.Column():
273
+ with gr.Row():
274
+ with gr.Column():
275
+ init_time = gr.Dropdown(
276
+ [v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
277
+ label="Initialization time",
278
+ multiselect=False,
279
+ type="index",
280
+ )
281
+ with gr.Column():
282
+ plt_channel = gr.Dropdown(
283
+ [c.upper() for c in SDO_CHANNELS],
284
+ label="SDO Band",
285
+ value="AIA94",
286
+ multiselect=False,
287
+ type="index"
288
+ )
289
+ with gr.Row():
290
+ hours_ahead = gr.Slider(minimum=1.0, maximum=6.0, step=1.0, label="Forecast step [hours ahead]")
291
+ with gr.Row():
292
+ btn = gr.Button("Run")
293
+
294
+ with gr.Row():
295
+ with gr.Column():
296
+ input_timestamp_0 = gr.Textbox(label="Input 0")
297
+ input_image_0 = gr.Image()
298
+ with gr.Column():
299
+ input_timestamp_1 = gr.Textbox(label="Input 1")
300
+ input_image_1 = gr.Image()
301
+ with gr.Column():
302
+ output_timestamp = gr.Textbox(label="Prediction")
303
+ output_image = gr.Image()
304
+
305
+ btn.click(
306
+ fn=run_inference,
307
+ inputs=[init_time, plt_channel, hours_ahead],
308
+ outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image]
309
+ )
310
+
311
+ with gr.Row():
312
+ gr.Examples(
313
+ examples=[
314
+ ["2014-01-07 17:24", "AIA94", 2],
315
+ ["2014-01-07 16:12", "AIA94", 6],
316
+ ["2014-01-07 16:00", "AIA131", 1],
317
+ ["2014-01-07 16:00", "HMI_M", 2],
318
+ ],
319
+ fn=run_inference,
320
+ inputs=[init_time, plt_channel, hours_ahead],
321
+ outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image],
322
+ cache_examples=False,
323
+ )
324
+
325
+ demo.launch(server_name=hostname, server_port=None)
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ einops==0.8.1
+ gradio==5.43.1
+ hdf5plugin==5.1.0
+ huggingface_hub==0.34.4
+ matplotlib==3.10.5
+ numba==0.61.2
+ numpy==2.3.2
+ packaging==25.0
+ pandas==2.3.1
+ PyYAML==6.0.2
+ scikit-image
+ sunpy==6.1.1
+ timm==1.0.19
+ torch==2.6.0
+ wandb==0.21.1
+ xarray==2025.3.1
surya/__init__.py ADDED
File without changes
surya/datasets/__init__.py ADDED
File without changes
surya/datasets/helio.py ADDED
@@ -0,0 +1,524 @@
1
+ import os
2
+ import sys
3
+ import random
4
+ from datetime import datetime
5
+ import torch
6
+ import numpy as np
7
+ import skimage.measure
8
+ import xarray as xr
9
+ import pandas as pd
10
+ from logging import Logger
11
+ from torch.utils.data import Dataset
12
+ from surya.utils.distributed import get_rank
13
+ from surya.utils.log import create_logger
14
+ from functools import cache
15
+
16
+ from numba import njit, prange
17
+
18
+ import hdf5plugin
19
+
20
+
21
+ @njit(parallel=True)
22
+ def fast_transform(data, means, stds, sl_scale_factors, epsilons):
23
+ """
24
+ Implements signum log transform using numba for speed
25
+ Notes:
26
+ - This must reside outside the class definition from which it is called.
27
+ - We used this function during pretraining for faster data loading. On select
28
+ GPU clusters, however, it can hang the system when data loading happens
29
+ outside the GPU thread. See below for a non-numba-enhanced version.
30
+
31
+ Args:
32
+ data: Numpy array of shape C, H, W
33
+ means: Numpy array of shape C. Mean per channel.
34
+ stds: Numpy array of shape C. Standard deviation per channel.
35
+ sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
36
+ epsilons: Numpy array of shape C. Constant to avoid zero division.
37
+
38
+ Returns:
39
+ Numpy array of shape C, H, W.
40
+ """
41
+ C, H, W = data.shape
42
+ out = np.empty((C, H, W), dtype=np.float32)
43
+ for c in prange(C):
44
+ mean = means[c]
45
+ std = stds[c]
46
+ eps = epsilons[c]
47
+ sl_scale_factor = sl_scale_factors[c]
48
+ for i in range(H):
49
+ for j in range(W):
50
+ val = data[c, i, j]
51
+ val = val * sl_scale_factor
52
+ if val >= 0:
53
+ val = np.log1p(val)
54
+ else:
55
+ val = -np.log1p(-val)
56
+ out[c, i, j] = (val - mean) / (std + eps)
57
+ return out
58
+
59
+ def transform(
60
+ data: np.ndarray,
61
+ means: np.ndarray,
62
+ stds: np.ndarray,
63
+ sl_scale_factors: np.ndarray,
64
+ epsilons: np.ndarray
65
+ ) -> np.ndarray:
66
+ """
67
+ Implements signum log transform. Drop-in replacement for
68
+ `fast_transform` method above.
69
+
70
+ Args:
71
+ data: Numpy array of shape C, H, W
72
+ means: Numpy array of shape C. Mean per channel.
73
+ stds: Numpy array of shape C. Standard deviation per channel.
74
+ sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
75
+ epsilons: Numpy array of shape C. Constant to avoid zero division.
76
+
77
+ Returns:
78
+ Numpy array of shape C, H, W.
79
+ """
80
+ means = means.reshape(*means.shape, 1, 1)
81
+ stds = stds.reshape(*stds.shape, 1, 1)
82
+ sl_scale_factors = sl_scale_factors.reshape(*sl_scale_factors.shape, 1, 1)
83
+ epsilons = epsilons.reshape(*epsilons.shape, 1, 1)
84
+
85
+ data = data * sl_scale_factors
86
+ data = np.sign(data) * np.log1p(np.abs(data))
87
+ data = (data - means) / (stds + epsilons)
88
+
89
+ return data
90
+
91
+ @njit(parallel=True)
92
+ def inverse_fast_transform(data, means, stds, sl_scale_factors, epsilons):
93
+ """
94
+ Implements inverse signum log transform using numba for speed
95
+
96
+ Args:
97
+ data: Numpy array of shape C, H, W
98
+ means: Numpy array of shape C. Mean per channel.
99
+ stds: Numpy array of shape C. Standard deviation per channel.
100
+ sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
101
+ epsilons: Numpy array of shape C. Constant to avoid zero division.
102
+
103
+ Returns:
104
+ Numpy array of shape C, H, W.
105
+ """
106
+ C, H, W = data.shape
107
+ out = np.empty((C, H, W), dtype=np.float32)
108
+
109
+ for c in prange(C):
110
+ mean = means[c]
111
+ std = stds[c]
112
+ eps = epsilons[c]
113
+ sl_scale_factor = sl_scale_factors[c]
114
+
115
+ for i in range(H):
116
+ for j in range(W):
117
+ val = data[c, i, j]
118
+ val = val * (std + eps) + mean
119
+
120
+ if val >= 0:
121
+ val = np.expm1(val)
122
+ else:
123
+ val = -np.expm1(-val)
124
+
125
+ val = val / sl_scale_factor
126
+
127
+ out[c, i, j] = val
128
+
129
+ return out
130
+
131
+
132
+ def inverse_transform_single_channel(data, mean, std, sl_scale_factor, epsilon):
133
+ """
134
+ Implements inverse signum log transform.
135
+
136
+ Args:
137
+ data: Numpy array of shape H, W.
138
+ mean: Float. Mean of the channel.
139
+ std: Float. Standard deviation of the channel.
140
+ sl_scale_factor: Float. Signum-log scale factor of the channel.
141
+ epsilon: Float. Constant to avoid zero division.
142
+
143
+ Returns:
144
+ Numpy array of shape H, W.
145
+ """
146
+ data = data * (std + epsilon) + mean
147
+
148
+ data = np.sign(data) * np.expm1(np.abs(data))
149
+
150
+ data = data / sl_scale_factor
151
+
152
+ return data
153
+
154
+
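+ def _demo_signum_log_roundtrip() -> None:
+     """
+     Minimal sketch with made-up values (not taken from scalers.yaml): `transform`
+     applies sign(x) * log1p(|x * scale|) followed by standardisation, and
+     `inverse_transform_single_channel` undoes both steps for one channel.
+     """
+     data = np.array([[[-10.0, 0.0, 250.0]]])                  # (C=1, H=1, W=3)
+     mean, std = np.zeros(1), np.ones(1)
+     eps, scale = np.full(1, 1e-8), np.ones(1)
+     forward = transform(data, mean, std, scale, eps)
+     back = inverse_transform_single_channel(
+         forward[0], mean=mean[0], std=std[0],
+         sl_scale_factor=scale[0], epsilon=eps[0],
+     )
+     assert np.allclose(back, data[0])
+
+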
155
+ class RandomChannelMaskerTransform:
156
+ def __init__(
157
+ self, num_channels, num_mask_aia_channels, phase, drop_hmi_probability
158
+ ):
159
+ """
160
+ Initialize the RandomChannelMaskerTransform class as a transform.
161
+
162
+ Args:
163
+ - num_channels: Total number of channels in the input (3rd dimension of
164
+ the tensor).
165
+ - num_mask_aia_channels: Number of channels to randomly mask.
166
+ """
167
+ self.num_channels = num_channels
168
+ self.num_mask_aia_channels = num_mask_aia_channels
169
+ self.drop_hmi_probability = drop_hmi_probability
170
+
171
+ def __call__(self, input_tensor):
172
+ C, T, H, W = input_tensor.shape # Unpack channels, time, height, width
173
+
174
+ # Randomly select channels to mask
175
+ channels_to_mask = random.sample(range(C), self.num_mask_aia_channels)
176
+
177
+ # Create a broadcastable mask of shape [num_channels, 1, 1, 1]
179
+ mask = torch.ones((C, 1, 1, 1))
180
+ mask[channels_to_mask, ...] = 0 # Set selected channels to zero
181
+
182
+ # Apply the mask; broadcasting zeroes the selected channels over time and space
183
+ masked_tensor = input_tensor * mask
183
+
184
+ if self.drop_hmi_probability > random.random():
185
+ masked_tensor[-1, ...] = 0
186
+
187
+ return masked_tensor
188
+
189
+
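+ def _demo_channel_masker() -> None:
+     """
+     Minimal sketch with illustrative sizes (13 channels, 2 timestamps, 4x4 pixels):
+     the masker zeroes `num_mask_aia_channels` randomly chosen channels and, with
+     probability `drop_hmi_probability`, the last channel as well.
+     """
+     masker = RandomChannelMaskerTransform(
+         num_channels=13, num_mask_aia_channels=2, phase="train", drop_hmi_probability=0.0
+     )
+     masked = masker(torch.ones(13, 2, 4, 4))                  # (C, T, H, W)
+     assert int((masked.sum(dim=(1, 2, 3)) == 0).sum()) == 2   # exactly two channels zeroed
+
+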
190
+ class HelioNetCDFDataset(Dataset):
191
+ """
192
+ PyTorch dataset to load a curated dataset from the NASA Solar Dynamics
193
+ Observatory (SDO) mission stored as NetCDF files, with handling for variable timesteps.
194
+
195
+ Internally maintains two databases. The first is `self.index`. This takes the
196
+ form
197
+ path present
198
+ timestep
199
+ 2011-01-01 00:00:00 /lustre/fs0/scratch/shared/data/2011/01/Arka_2... 1
200
+ 2011-01-01 00:12:00 /lustre/fs0/scratch/shared/data/2011/01/Arka_2... 1
201
+ ... ... ...
202
+ 2012-11-30 23:48:00 /lustre/fs0/scratch/shared/data/2012/11/Arka_2... 1
203
+
204
+ The second is `self.valid_indices`. This is simply a list of timesteps -- entries
205
+ in the index of `self.index` -- which define valid samples. A sample is valid
206
+ when all timestamps that can be reached by entris in
207
+ time_delta_input_minutes and time_delta_target_minutes can be reached from it
208
+ are present.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ index_path: str,
214
+ time_delta_input_minutes: list[int],
215
+ time_delta_target_minutes: int,
216
+ n_input_timestamps: int,
217
+ rollout_steps: int,
218
+ scalers=None,
219
+ num_mask_aia_channels: int = 0,
220
+ drop_hmi_probability: float = 0.0,
221
+ use_latitude_in_learned_flow=False,
222
+ channels: list[str] | None = None,
223
+ phase="train",
224
+ pooling: int | None = None,
225
+ random_vert_flip: bool = False,
226
+ ):
227
+ self.scalers = scalers
228
+ self.phase = phase
229
+ self.channels = channels
230
+ self.num_mask_aia_channels = num_mask_aia_channels
231
+ self.drop_hmi_probability = drop_hmi_probability
232
+ self.n_input_timestamps = n_input_timestamps
233
+ self.rollout_steps = rollout_steps
234
+ self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
235
+ self.pooling = pooling if pooling is not None else 1
236
+ self.random_vert_flip = random_vert_flip
237
+
238
+ if self.channels is None:
239
+ # AIA + HMI channels
240
+ self.channels = [
241
+ "0094",
242
+ "0131",
243
+ "0171",
244
+ "0193",
245
+ "0211",
246
+ "0304",
247
+ "0335",
248
+ "hmi",
249
+ ]
250
+ self.in_channels = len(self.channels)
251
+
252
+ self.masker = RandomChannelMaskerTransform(
253
+ num_channels=self.in_channels,
254
+ num_mask_aia_channels=self.num_mask_aia_channels,
255
+ phase=self.phase,
256
+ drop_hmi_probability=self.drop_hmi_probability,
257
+ )
258
+
259
+ # Convert time delta to numpy timedelta64
260
+ self.time_delta_input_minutes = sorted(
261
+ np.timedelta64(t, "m") for t in time_delta_input_minutes
262
+ )
263
+ self.time_delta_target_minutes = [
264
+ np.timedelta64(iroll * time_delta_target_minutes, "m")
265
+ for iroll in range(1, rollout_steps + 2)
266
+ ]
267
+
268
+ # Create the index
269
+ self.index = pd.read_csv(index_path)
270
+ self.index = self.index[self.index["present"] == 1]
271
+ self.index["timestep"] = pd.to_datetime(self.index["timestep"]).values.astype(
272
+ "datetime64[ns]"
273
+ )
274
+ self.index.set_index("timestep", inplace=True)
275
+ self.index.sort_index(inplace=True)
276
+
277
+ # Filter out rows where the sequence is not fully present
278
+ self.valid_indices = self.filter_valid_indices()
279
+ self.adjusted_length = len(self.valid_indices)
280
+
281
+ self.rank = get_rank()
282
+ self.logger: Logger | None = None
283
+
284
+ def create_logger(self):
285
+ """
286
+ Creates a logger attached to self.logger.
287
+ The logger is identified by SLURM job ID
288
+ as well as the data processes rank and process ID.
289
+ """
290
+ os.makedirs("logs/data", exist_ok=True)
291
+ timestamp = datetime.now().strftime("%Y%m%dT%H%M%SZ")
292
+ pid = os.getpid()
293
+ self.logger = create_logger(
294
+ output_dir="logs/data",
295
+ dist_rank=self.rank,
296
+ name=f"{timestamp}_{self.rank:>03}_data_{self.phase}_{pid}",
297
+ )
298
+
299
+ def filter_valid_indices(self):
300
+ """
301
+ Extracts timestamps from the index of self.index that define valid
302
+ samples.
303
+
304
+ Args:
305
+ Returns:
306
+ List of timestamps.
307
+ """
308
+
309
+ valid_indices = []
310
+ time_deltas = np.unique(
311
+ self.time_delta_input_minutes + self.time_delta_target_minutes
312
+ )
313
+
314
+ for reference_timestep in self.index.index:
315
+ required_timesteps = reference_timestep + time_deltas
316
+
317
+ if all(t in self.index.index for t in required_timesteps):
318
+ valid_indices.append(reference_timestep)
319
+
320
+ return valid_indices
321
+
322
+ def __len__(self):
323
+ return self.adjusted_length
324
+
325
+ def __getitem__(self, idx: int) -> dict:
326
+ """
327
+ Args:
328
+ idx: Index of sample to load. (Pytorch standard.)
329
+ Returns:
330
+ Dictionary with following keys. The values are tensors with shape as follows:
331
+ ts (torch.Tensor): C, T, H, W
332
+ time_delta_input (torch.Tensor): T
333
+ input_latitude (torch.Tensor): T
334
+ forecast (torch.Tensor): C, L, H, W
335
+ lead_time_delta (torch.Tensor): L
336
+ forecast_latitude (torch.Tensor): L
337
+ C - Channels, T - Input times, H - Image height, W - Image width, L - Lead time.
338
+ """
339
+ if self.logger is None:
340
+ self.create_logger()
341
+ self.logger.info(f"HelioNetCDFDataset of length {self.__len__()}.")
342
+
343
+ exception_counter = 0
344
+ max_exception = 100
345
+
346
+ self.logger.info(f"Starting to retrieve index {idx}.")
347
+
348
+ while True:
349
+ try:
350
+ sample = self._get_index_data(idx)
351
+ except Exception as e:
352
+ exception_counter += 1
353
+ if exception_counter >= max_exception:
354
+ raise e
355
+
356
+ reference_timestep = self.valid_indices[idx]
357
+ self.logger.warning(
358
+ f"Failed retrieving index {idx}. Timestamp {reference_timestep}. Attempt {exception_counter}."
359
+ )
360
+
361
+ idx = (idx + 1) % self.__len__()
362
+ else:
363
+ self.logger.info(f"Returning index {idx}.")
364
+ return sample
365
+
366
+ def _get_index_data(self, idx: int) -> dict:
367
+ """
368
+ Args:
369
+ idx: Index of sample to load. (Pytorch standard.)
370
+ Returns:
371
+ Dictionary with following keys. The values are tensors with shape as follows:
372
+ ts (torch.Tensor): C, T, H, W
373
+ time_delta_input (torch.Tensor): T
374
+ input_latitude (torch.Tensor): T
375
+ forecast (torch.Tensor): C, L, H, W
376
+ lead_time_delta (torch.Tensor): L
377
+ forecast_latitude (torch.Tensor): L
378
+ C - Channels, T - Input times, H - Image height, W - Image width, L - Lead time.
379
+ """
380
+ # start_time = time.time()
381
+
382
+ time_deltas = np.array(
383
+ sorted(
384
+ random.sample(
385
+ self.time_delta_input_minutes[:-1], self.n_input_timestamps - 1
386
+ )
387
+ )
388
+ + [self.time_delta_input_minutes[-1]]
389
+ + self.time_delta_target_minutes
390
+ )
391
+ reference_timestep = self.valid_indices[idx]
392
+ required_timesteps = reference_timestep + time_deltas
393
+
394
+ sequence_data = [
395
+ self.transform_data(
396
+ self.load_nc_data(
397
+ self.index.loc[timestep, "path"], timestep, self.channels
398
+ )
399
+ )
400
+ for timestep in required_timesteps
401
+ ]
402
+
403
+ # Split sequence_data into inputs and target
404
+ inputs = sequence_data[: -self.rollout_steps - 1]
405
+ targets = sequence_data[-self.rollout_steps - 1 :]
406
+
407
+ stacked_inputs = np.stack(inputs, axis=1)
408
+ stacked_targets = np.stack(targets, axis=1)
409
+
410
+ timestamps_input = required_timesteps[: -self.rollout_steps - 1]
411
+ timestamps_targets = required_timesteps[-self.rollout_steps - 1 :]
412
+
413
+ if self.num_mask_aia_channels > 0 or self.drop_hmi_probability:
414
+ # assert 0 < self.num_mask_aia_channels < self.in_channels, \
415
+ # f'num_mask_aia_channels = {self.num_mask_aia_channels} should lie between 0 and {self.in_channels}'
416
+
417
+ stacked_inputs = self.masker(stacked_inputs)
418
+
419
+ time_delta_input_float = (
420
+ time_deltas[-self.rollout_steps - 2]
421
+ - time_deltas[: -self.rollout_steps - 1]
422
+ ) / np.timedelta64(1, "h")
423
+ time_delta_input_float = time_delta_input_float.astype(np.float32)
424
+
425
+ lead_time_delta_float = (
426
+ time_deltas[-self.rollout_steps - 2]
427
+ - time_deltas[-self.rollout_steps - 1 :]
428
+ ) / np.timedelta64(1, "h")
429
+ lead_time_delta_float = lead_time_delta_float.astype(np.float32)
430
+
431
+ # print('LocalRank', int(os.environ["LOCAL_RANK"]),
432
+ # 'GlobalRank', int(os.environ["RANK"]),
433
+ # 'worker', torch.utils.data.get_worker_info().id,
434
+ # f': Processed Input: {idx} ',time.time()- start_time)
435
+
436
+ metadata = {
437
+ "timestamps_input": timestamps_input,
438
+ "timestamps_targets": timestamps_targets,
439
+ }
440
+
441
+ if self.random_vert_flip:
442
+ if torch.bernoulli(torch.ones(()) / 2) == 1:
443
+ stacked_inputs = torch.flip(stacked_inputs, dims=(-2,))
444
+ stacked_targets = torch.flip(stacked_targets, dims=(-2,))
445
+
446
+ if self.use_latitude_in_learned_flow:
447
+ from sunpy.coordinates.ephemeris import get_earth
448
+
449
+ sequence_latitude = [
450
+ get_earth(timestep).lat.value for timestep in required_timesteps
451
+ ]
452
+ input_latitudes = sequence_latitude[: -self.rollout_steps - 1]
453
+ target_latitude = sequence_latitude[-self.rollout_steps - 1 :]
454
+
455
+ return {
456
+ "ts": stacked_inputs,
457
+ "time_delta_input": time_delta_input_float,
458
+ "input_latitudes": input_latitudes,
459
+ "forecast": stacked_targets,
460
+ "lead_time_delta": lead_time_delta_float,
461
+ "forecast_latitude": target_latitude,
462
+ }, metadata
463
+
464
+ return {
465
+ "ts": stacked_inputs,
466
+ "time_delta_input": time_delta_input_float,
467
+ "forecast": stacked_targets,
468
+ "lead_time_delta": lead_time_delta_float,
469
+ }, metadata
470
+
471
+ def load_nc_data(
472
+ self, filepath: str, timestep: pd.Timestamp, channels: list[str]
473
+ ) -> np.ndarray:
474
+ """
475
+ Args:
476
+ filepath: String or Pathlike. Points to NetCDF file to open.
477
+ timestep: Identifies timestamp to retrieve.
478
+ Returns:
479
+ Numpy array of shape (C, H, W).
480
+ """
481
+ self.logger.info(f"Reading file {filepath}.")
482
+
483
+ with xr.open_dataset(
484
+ filepath, engine="h5netcdf", chunks=None, cache=False,
485
+ ) as ds:
486
+ data = ds[channels].to_array().load().to_numpy()
487
+
488
+ return data
489
+
490
+ @cache
491
+ def transformation_inputs(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
492
+ means = np.array([self.scalers[ch].mean for ch in self.channels])
493
+ stds = np.array([self.scalers[ch].std for ch in self.channels])
494
+ epsilons = np.array([self.scalers[ch].epsilon for ch in self.channels])
495
+ sl_scale_factors = np.array(
496
+ [self.scalers[ch].sl_scale_factor for ch in self.channels]
497
+ )
498
+
499
+ return means, stds, epsilons, sl_scale_factors
500
+
501
+ def transform_data(self, data: np.ndarray) -> np.ndarray:
502
+ """
503
+ Applies scalers.
504
+
505
+ Args:
506
+ data: Numpy array of shape (C, H, W)
507
+ Returns:
508
+ Tensor of shape (C, H, W). Data type float32.
509
+ Uses:
510
+ numba to speed up transform
511
+ tvk-srm-heliofm environment cloned from srm-heliofm with numba added
512
+ tvk_dgx_slurm.sh shell script modified to use new environment and new jobname
513
+ train_spectformer_dgx.yaml new jobname
514
+ """
515
+ assert data.ndim == 3
516
+
517
+ if self.pooling > 1:
518
+ data = skimage.measure.block_reduce(
519
+ data, block_size=(1, self.pooling, self.pooling), func=np.mean
520
+ )
521
+
522
+ means, stds, epsilons, sl_scale_factors = self.transformation_inputs()
523
+ result_np = transform(data, means, stds, sl_scale_factors, epsilons)
524
+ return result_np
surya/datasets/transformations.py ADDED
@@ -0,0 +1,456 @@
1
+ import abc
2
+ from logging import info
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import xarray as xr
8
+
9
+
10
+ class Transformation(object):
11
+ @abc.abstractmethod
12
+ def fit(self, data: xr.DataArray):
13
+ raise NotImplementedError()
14
+
15
+ @abc.abstractmethod
16
+ def transform(self, data: xr.DataArray):
17
+ raise NotImplementedError()
18
+
19
+ @abc.abstractmethod
20
+ def inverse_transform(self, data: xr.DataArray):
21
+ raise NotImplementedError()
22
+
23
+ @abc.abstractmethod
24
+ def fit_transform(self, data: xr.DataArray):
25
+ return self.fit(data).transform(data)
26
+
27
+ @abc.abstractmethod
28
+ def to_dict(self) -> dict:
29
+ raise NotImplementedError()
30
+
31
+ @staticmethod
32
+ @abc.abstractmethod
33
+ def from_dict(info: dict):
34
+ raise NotImplementedError()
35
+
36
+ @abc.abstractmethod
37
+ def reset(self):
38
+ raise NotImplementedError()
39
+
40
+
41
+ class MinMaxScaler(Transformation):
42
+ """_summary_
43
+ Minmax scaling on the entire data
44
+ """
45
+
46
+ def __init__(self, new_min=1, new_max=2):
47
+ self._is_fitted = False
48
+ self.new_min = new_min
49
+ self.new_max = new_max
50
+ self._min = None
51
+ self._max = None
52
+
53
+ @property
54
+ def min(self) -> float:
55
+ return self._min
56
+
57
+ @property
58
+ def max(self) -> float:
59
+ return self._max
60
+
61
+ @property
62
+ def is_fitted(self) -> bool:
63
+ return self._is_fitted
64
+
65
+ def fit(self, data: xr.DataArray):
66
+ if not self.is_fitted:
67
+ self._max = data.max().values
68
+ self._min = data.min().values
69
+ self._is_fitted = True
70
+ else:
71
+ info("Already fitted, skipping function.")
72
+ return self
73
+
74
+ def _transform(self, data: xr.DataArray):
75
+ return (
76
+ ((data - self.min) / (self.max - self.min)) * (self.new_max - self.new_min)
77
+ ) + self.new_min
78
+
79
+ def transform(self, data: xr.DataArray) -> xr.DataArray:
80
+ assert self.min is not None and self.max is not None, "You must run fit first."
81
+
82
+ data = xr.apply_ufunc(self._transform, data, dask="forbidden")
83
+
84
+ return data
85
+
86
+ def fit_transform(self, data):
87
+ self.fit(data)
88
+ return self.transform(data)
89
+
90
+ def inverse_transform(self, data):
91
+ return data * (self.max - self.min) + self.min
92
+
93
+ def to_dict(self) -> dict:
94
+ out_dict = {
95
+ "base": self.__module__,
96
+ "class": self.__class__.__name__,
97
+ "new_min": str(self.new_min),
98
+ "new_max": str(self.new_max),
99
+ "min": str(self.min),
100
+ "max": str(self.max),
101
+ "is_fitted": self.is_fitted,
102
+ }
103
+ return out_dict
104
+
105
+ @staticmethod
106
+ def from_dict(info: dict):
107
+ # with open(yaml_path, 'r') as file:
108
+ # data = yaml.load(file, Loader=yaml.SafeLoader)
109
+ out = MinMaxScaler(
110
+ new_min=np.float32(info["new_min"]), new_max=np.float32(info["new_max"])
111
+ )
112
+ out._min = np.float32(info["min"])
113
+ out._max = np.float32(info["max"])
114
+ out._is_fitted = info["is_fitted"]
115
+ return out
116
+
117
+ def reset(self):
118
+ self.__init__(self.new_min, self.new_max)
119
+
120
+ def __str__(self):
121
+ return (
122
+ f"min: {self.min}, "
123
+ f"max: {self.max}, "
124
+ f"new_max: {self.new_max}, "
125
+ f"new_min: {self.new_min}"
126
+ )
127
+
128
+
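+ def _demo_minmax_scaler() -> None:
+     """
+     Minimal sketch with made-up values: fitting on a small DataArray maps its range
+     onto [new_min, new_max], here the default [1, 2].
+     """
+     arr = xr.DataArray(np.array([0.0, 5.0, 10.0]))
+     scaler = MinMaxScaler(new_min=1, new_max=2)
+     scaled = scaler.fit_transform(arr)
+     assert float(scaled.min()) == 1.0 and float(scaled.max()) == 2.0
+
+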
129
+ class StandardScaler(Transformation):
130
+ """_summary_
131
+ Standard scaling on the entire data
132
+ """
133
+
134
+ def __init__(self, epsilon=1e-8):
135
+ self.epsilon = epsilon
136
+ self._is_fitted = False
137
+ self._mean = None
138
+ self._std = None
139
+ self._min = None
140
+ self._max = None
141
+ self._sl_scale_factor = None
142
+
143
+ @property
144
+ def mean(self) -> float:
145
+ return self._mean
146
+
147
+ @property
148
+ def std(self) -> float:
149
+ return self._std
150
+
151
+ @property
152
+ def min(self) -> float:
153
+ return self._min
154
+
155
+ @property
156
+ def max(self) -> float:
157
+ return self._max
158
+
159
+ @property
160
+ def sl_scale_factor(self) -> float:
161
+ return self._sl_scale_factor
162
+
163
+ @property
164
+ def is_fitted(self) -> bool:
165
+ return self._is_fitted
166
+
167
+ def fit(self, data):
168
+ if not self.is_fitted:
169
+ self._mean = data.mean().values
170
+ self._std = data.std().values
171
+ self._min = data.min().values
172
+ self._max = data.max().values
173
+ self._is_fitted = True
174
+ else:
175
+ info("Already fitted, skipping function.")
176
+
177
+ return self
178
+
179
+ def _transform(self, data: xr.DataArray):
180
+ return (data - self.mean) / (self.std + self.epsilon)
181
+
182
+ def _signum_log_transform(self, data: xr.DataArray):
183
+ data = data * self.sl_scale_factor
184
+ return np.sign(data) * np.log1p(np.abs(data))
185
+
186
+ def signum_log_transform(self, data: xr.DataArray):
187
+ assert self.mean is not None and self.std is not None, "You must run fit first."
188
+
189
+ data = xr.apply_ufunc(self._signum_log_transform, data, dask="forbidden")
190
+ data = xr.apply_ufunc(self._transform, data, dask="forbidden")
191
+ return data
192
+
193
+ def transform(self, data: xr.DataArray):
194
+ assert self.mean is not None and self.std is not None, "You must run fit first."
195
+
196
+ data = xr.apply_ufunc(self._transform, data, dask="forbidden")
197
+ return data
198
+
199
+ def fit_transform(self, data: xr.DataArray):
200
+ self.fit(data)
201
+ return self.transform(data)
202
+
203
+ def inverse_transform(self, data):
204
+ if isinstance(data, torch.Tensor):
205
+ return data * (
206
+ torch.Tensor([self.std]).to(data.device)
207
+ + torch.Tensor([self.epsilon]).to(data.device)
208
+ ) + torch.Tensor([self.mean]).to(data.device)
209
+ else:
210
+ return data * (self.std + self.epsilon) + self.mean
211
+
212
+ def inverse_signum_log_transform(self, data):
213
+ if isinstance(data, torch.Tensor):
214
+ return (
215
+ torch.sign(data)
216
+ * torch.expm1(torch.abs(data))
217
+ / torch.Tensor([self.sl_scale_factor]).to(data.device)
218
+ )
219
+ else:
220
+ return np.sign(data) * np.expm1(np.abs(data)) / self.sl_scale_factor
221
+
222
+ def to_dict(self) -> dict:
223
+ return {
224
+ "base": self.__module__,
225
+ "class": self.__class__.__name__,
226
+ "epsilon": str(self.epsilon),
227
+ "mean": str(self.mean),
228
+ "std": str(self.std),
229
+ "is_fitted": self.is_fitted,
230
+ "min": str(self.min),
231
+ "max": str(self.max),
232
+ "sl_scale_factor": str(self.sl_scale_factor),
233
+ }
234
+
235
+ @staticmethod
236
+ def from_dict(info: dict):
237
+ out = StandardScaler(epsilon=np.float32(info["epsilon"]))
238
+ out._mean = np.float32(info["mean"])
239
+ out._std = np.float32(info["std"])
240
+ out._is_fitted = info["is_fitted"]
241
+ out._min = np.float32(info["min"])
242
+ out._max = np.float32(info["max"])
243
+ out._sl_scale_factor = np.float32(info["sl_scale_factor"])
244
+ return out
245
+
246
+ def reset(self):
247
+ self.__init__(self.epsilon)
248
+
249
+ def __str__(self):
250
+ return f"mean: {self.mean}, " f"std: {self.std}, " f"epsilon: {self.epsilon}"
251
+
252
+
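+ def _demo_standard_scaler_roundtrip() -> None:
+     """
+     Minimal sketch with made-up statistics (not values from scalers.yaml): a fitted
+     StandardScaler serialises to a plain dict via `to_dict` and is rebuilt with
+     `from_dict`. The private attributes are set directly here purely for illustration.
+     """
+     scaler = StandardScaler(epsilon=1e-8)
+     scaler._mean, scaler._std = np.float32(1.5), np.float32(0.5)
+     scaler._min, scaler._max = np.float32(0.0), np.float32(10.0)
+     scaler._sl_scale_factor = np.float32(1.0)
+     scaler._is_fitted = True
+     restored = StandardScaler.from_dict(scaler.to_dict())
+     assert restored.mean == scaler.mean and restored.std == scaler.std
+
+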
253
+ class MaskUnits2D:
254
+ """
255
+ Transformation that takes a tuple of numpy tensors and returns a sequence of mask units. These are generally in the form `channel, dim_0, dim_1, dim_2, ...`. The returned data is largely of shape `mask unit sequence, channel, lat, lon`. Masked patches are not returned.
256
+ The return values contain sets of indices. The indices indicate which mask units were dropped (masked) or not. The 1D indexing here simply relies on flattening the 2D space of mask units. The class methods `reconstruct` and `reconstruct_batch` show how to re-assemble the entire sequence.
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_lat_mu: int,
262
+ n_lon_mu: int,
263
+ padding,
264
+ seed=None,
265
+ mask_ratio_vals: float = 0.5,
266
+ mask_ratio_tars: float = 0.0,
267
+ n_lats: int = 361,
268
+ n_lons: int = 576,
269
+ ):
270
+ self.n_lat_mu = n_lat_mu
271
+ self.n_lon_mu = n_lon_mu
272
+ self.mask_ratio_vals = mask_ratio_vals
273
+ self.mask_ratio_tars = mask_ratio_tars
274
+ self.padding = padding
275
+ self.n_lats = n_lats + padding[0][0] + padding[0][1]
276
+ self.n_lons = n_lons + padding[1][0] + padding[1][1]
277
+
278
+ if self.n_lats % n_lat_mu != 0:
279
+ raise ValueError(
280
+ f"Padded latitudes {self.n_lats} are not an integer multiple of the mask unit size {n_lat_mu}."
281
+ )
282
+ if self.n_lons % n_lon_mu != 0:
283
+ raise ValueError(
284
+ f"Padded longitudes {self.n_lons} are not an integer multiple of the mask unit size {n_lon_mu}."
285
+ )
286
+
287
+ self.mask_shape = (self.n_lats // self.n_lat_mu, self.n_lons // self.n_lon_mu)
288
+
289
+ self.rng = np.random.default_rng(seed=seed)
290
+
291
+ def n_units_masked(self, mask_type="vals"):
292
+ if mask_type == "vals":
293
+ return int(self.mask_ratio_vals * np.prod(self.mask_shape))
294
+ elif mask_type == "tars":
295
+ return int(self.mask_ratio_tars * np.prod(self.mask_shape))
296
+ else:
297
+ raise ValueError(
298
+ f"`{mask_type}` not an allowed value for `mask_type`. Use `vals` or `tars`."
299
+ )
300
+
301
+ @staticmethod
302
+ def reconstruct(
303
+ idx_masked: torch.Tensor,
304
+ idx_unmasked: torch.Tensor,
305
+ data_masked: torch.Tensor,
306
+ data_unmasked: torch.Tensor,
307
+ ) -> torch.Tensor:
308
+ """
309
+ Reconstructs a tensor along the mask unit dimension. Non-batched version.
310
+
311
+ Args:
312
+ idx_masked: Tensor of shape `mask unit sequence`.
313
+ idx_unmasked: Tensor of shape `mask unit sequence`.
314
+ data_masked: Tensor of shape `mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_masked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_unmasked.
315
+ data_unmasked: Tensor of shape `mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_unmasked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_masked.
316
+ Returns:
317
+ Tensor of same shape as inputs data_masked and data_unmasked. I.e. `mask unit sequence, ...`.
318
+ """
319
+ idx_total = torch.argsort(torch.cat([idx_masked, idx_unmasked], dim=0), dim=0)
320
+ idx_total = idx_total.reshape(
321
+ *idx_total.shape,
322
+ *[1 for _ in range(len(idx_total.shape), len(data_unmasked.shape))],
323
+ )
324
+ idx_total = idx_total.expand(*idx_total.shape[:1], *data_unmasked.shape[1:])
325
+ data = torch.cat([data_masked, data_unmasked], dim=0)
326
+ data = torch.gather(data, dim=0, index=idx_total)
327
+ return data
328
+
329
+ @staticmethod
330
+ def reconstruct_batch(
331
+ idx_masked: torch.Tensor,
332
+ idx_unmasked: torch.Tensor,
333
+ data_masked: torch.Tensor,
334
+ data_unmasked: torch.Tensor,
335
+ ) -> torch.Tensor:
336
+ """
337
+ Reconstructs a tensor along the mask unit dimension. Batched version.
338
+
339
+ Args:
340
+ idx_masked: Tensor of shape `batch, mask unit sequence`.
341
+ idx_unmasked: Tensor of shape `batch, mask unit sequence`.
342
+ data_masked: Tensor of shape `batch, mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_masked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_unmasked.
343
+ data_unmasked: Tensor of shape `batch, mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_unmasked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_masked.
344
+ Returns:
345
+ Tensor of same shape as inputs data_masked and data_unmasked. I.e. `batch, mask unit sequence, ...`.
346
+ """
347
+ idx_total = torch.argsort(torch.cat([idx_masked, idx_unmasked], dim=1), dim=1)
348
+ idx_total = idx_total.reshape(
349
+ *idx_total.shape,
350
+ *[1 for _ in range(len(idx_total.shape), len(data_unmasked.shape))],
351
+ )
352
+ idx_total = idx_total.expand(*idx_total.shape[:2], *data_unmasked.shape[2:])
353
+ data = torch.cat([data_masked, data_unmasked], dim=1)
354
+ data = torch.gather(data, dim=1, index=idx_total)
355
+ return data
356
+
357
+ def __call__(self, data: Tuple[np.ndarray, ...]) -> Tuple[torch.Tensor, ...]:
358
+ """
359
+ Args:
360
+ data: Tuple of numpy tensors. These are interpreted as `(sur_static, ulv_static, sur_vals, ulv_vals, sur_tars, ulv_tars)`.
361
+ Returns:
362
+ Tuple of torch tensors. If the target is unmasked (`mask_ratio_tars` is zero), the tuple contains
363
+ `(static, indices_masked_vals, indices_unmasked_vals, vals, tars)`. When targets are masked as well, we are dealing with
364
+ `(static, indices_masked_vals, indices_unmasked_vals, vals, indices_masked_tars, indices_unmasked_tars, tars)`.
365
+ Their shapes are as follows:
366
+ static: mask unit sequence, channel, lat, lon
367
+ indices_masked_vals: mask unit sequence
368
+ indices_unmasked_vals: mask unit sequence
369
+ vals: mask unit sequence, channel, lat, lon
370
+ tars: mask unit sequence, channel, lat, lon
371
+ """
372
+ sur_static, ulv_static, sur_vals, ulv_vals, sur_tars, ulv_tars = data
373
+
374
+ sur_vals, ulv_vals = np.squeeze(sur_vals, axis=1), np.squeeze(ulv_vals, axis=1)
375
+ sur_tars, ulv_tars = np.squeeze(sur_tars, axis=1), np.squeeze(ulv_tars, axis=1)
376
+
377
+ vals = np.concatenate(
378
+ [
379
+ sur_vals,
380
+ ulv_vals.reshape(
381
+ ulv_vals.shape[0] * ulv_vals.shape[1], *ulv_vals.shape[-2:]
382
+ ),
383
+ ],
384
+ axis=0,
385
+ )
386
+ tars = np.concatenate(
387
+ [
388
+ sur_tars,
389
+ ulv_tars.reshape(
390
+ ulv_tars.shape[0] * ulv_tars.shape[1], *ulv_tars.shape[-2:]
391
+ ),
392
+ ],
393
+ axis=0,
394
+ )
395
+
396
+ padding = ((0, 0), *self.padding)
397
+ static = np.pad(sur_static, padding)
398
+ vals = np.pad(vals, padding)
399
+ tars = np.pad(tars, padding)
400
+
401
+ static = static.reshape(
402
+ static.shape[0],
403
+ static.shape[-2] // self.n_lat_mu,
404
+ self.n_lat_mu,
405
+ static.shape[-1] // self.n_lon_mu,
406
+ self.n_lon_mu,
407
+ ).transpose(1, 3, 0, 2, 4)
408
+ vals = vals.reshape(
409
+ vals.shape[0],
410
+ vals.shape[-2] // self.n_lat_mu,
411
+ self.n_lat_mu,
412
+ vals.shape[-1] // self.n_lon_mu,
413
+ self.n_lon_mu,
414
+ ).transpose(1, 3, 0, 2, 4)
415
+ tars = tars.reshape(
416
+ tars.shape[0],
417
+ tars.shape[-2] // self.n_lat_mu,
418
+ self.n_lat_mu,
419
+ tars.shape[-1] // self.n_lon_mu,
420
+ self.n_lon_mu,
421
+ ).transpose(1, 3, 0, 2, 4)
422
+
423
+ maskable_indices = np.arange(np.prod(self.mask_shape))
424
+ maskable_indices = self.rng.permutation(maskable_indices)
425
+ indices_masked_vals = maskable_indices[: self.n_units_masked()]
426
+ indices_unmasked_vals = maskable_indices[self.n_units_masked() :]
427
+
428
+ vals = vals.reshape(-1, *vals.shape[2:])[indices_unmasked_vals, :, :, :]
429
+
430
+ if self.mask_ratio_tars > 0.0:
431
+ maskable_indices = np.arange(np.prod(self.mask_shape))
432
+ maskable_indices = self.rng.permutation(maskable_indices)
433
+ indices_masked_tars = maskable_indices[: self.n_units_masked("tars")]
434
+ indices_unmasked_tars = maskable_indices[self.n_units_masked("tars") :]
435
+
436
+ tars = tars.reshape(-1, *tars.shape[2:])[indices_unmasked_tars, :, :, :]
437
+
438
+ return_value = (
439
+ torch.from_numpy(static).flatten(0, 1),
440
+ torch.from_numpy(indices_masked_vals),
441
+ torch.from_numpy(indices_unmasked_vals),
442
+ torch.from_numpy(vals),
443
+ torch.from_numpy(indices_masked_tars),
444
+ torch.from_numpy(indices_unmasked_tars),
445
+ torch.from_numpy(tars),
446
+ )
447
+ return return_value
448
+ else:
449
+ return_value = (
450
+ torch.from_numpy(static).flatten(0, 1),
451
+ torch.from_numpy(indices_masked_vals),
452
+ torch.from_numpy(indices_unmasked_vals),
453
+ torch.from_numpy(vals),
454
+ torch.from_numpy(tars).flatten(0, 1),
455
+ )
456
+ return return_value
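+
+
+ def _demo_reconstruct() -> None:
+     """
+     Minimal sketch with hypothetical indices and payloads: `MaskUnits2D.reconstruct`
+     merges masked and unmasked mask units back into their original order. Here units
+     2 and 0 were masked, units 1 and 3 were kept, and the output restores order 0..3.
+     """
+     idx_masked = torch.tensor([2, 0])
+     idx_unmasked = torch.tensor([1, 3])
+     data_masked = torch.tensor([[20.0], [0.0]])     # payloads for units 2 and 0
+     data_unmasked = torch.tensor([[10.0], [30.0]])  # payloads for units 1 and 3
+     merged = MaskUnits2D.reconstruct(idx_masked, idx_unmasked, data_masked, data_unmasked)
+     assert torch.equal(merged, torch.tensor([[0.0], [10.0], [20.0], [30.0]]))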
surya/models/__init__.py ADDED
File without changes
surya/models/embedding.py ADDED
@@ -0,0 +1,483 @@
1
+ """
2
+ Perceiver code is based on Aurora: https://github.com/microsoft/aurora/blob/main/aurora/model/perceiver.py
3
+
4
+ Some conventions for notation:
5
+ B - Batch
6
+ T - Time
7
+ H - Height (pixel space)
8
+ W - Width (pixel space)
9
+ HT - Height (token space)
10
+ WT - Width (token space)
11
+ ST - Sequence (token space)
12
+ C - Input channels
13
+ D - Model (embedding) dimension
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from einops import rearrange
20
+ from timm.models.layers import trunc_normal_
21
+
22
+
23
+ class PatchEmbed3D(nn.Module):
24
+ """Timeseries Image to Patch Embedding"""
25
+
26
+ def __init__(
27
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2
28
+ ):
29
+ super().__init__()
30
+ self.img_size = img_size
31
+ self.patch_size = patch_size
32
+ self.embed_dim = embed_dim
33
+ self.time_dim = time_dim
34
+
35
+ self.proj = nn.Conv2d(
36
+ in_chans * time_dim,
37
+ embed_dim,
38
+ kernel_size=(patch_size, patch_size),
39
+ stride=(patch_size, patch_size),
40
+ )
41
+
42
+ def forward(self, x):
43
+ """
44
+ Args:
45
+ x: Tensor of shape (B, C, T, H, W)
46
+ Returns:
47
+ Tensor of shape (B, ST, D)
48
+ """
49
+ B, C, T, H, W = x.shape
50
+ x = self.proj(x.flatten(1, 2)) # (B, C, T, H, W) -> (B, D, HT, WT)
51
+ x = rearrange(x, "B D HT WT -> B (HT WT) D") # (B, N, D)
52
+ return x
53
+
54
+
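+ # Minimal sketch of the patch-embedding shape contract with illustrative sizes
+ # (not the Surya configuration): a (B, C, T, H, W) input is folded to C*T channels,
+ # patchified by a strided convolution, and flattened into a token sequence.
+ def _demo_patch_embed_shapes() -> None:
+     embed = PatchEmbed3D(img_size=16, patch_size=4, in_chans=13, embed_dim=32, time_dim=2)
+     x = torch.zeros(1, 13, 2, 16, 16)               # (B, C, T, H, W)
+     assert embed(x).shape == (1, 16, 32)            # (B, (H/4)*(W/4), D)
+
+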
55
+ class LinearEmbedding(nn.Module):
56
+ def __init__(
57
+ self,
58
+ img_size=224,
59
+ patch_size=16,
60
+ in_chans=3,
61
+ time_dim=2,
62
+ embed_dim=768,
63
+ drop_rate=0.0,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.num_patches = (img_size // patch_size) ** 2
68
+
69
+ self.patch_embed = PatchEmbed3D(
70
+ img_size=img_size,
71
+ patch_size=patch_size,
72
+ in_chans=in_chans,
73
+ embed_dim=embed_dim,
74
+ time_dim=time_dim,
75
+ )
76
+
77
+ self._generate_position_encoding(img_size, patch_size, embed_dim)
78
+
79
+ self.pos_drop = nn.Dropout(p=drop_rate)
80
+
81
+ def _generate_position_encoding(self, img_size, patch_size, embed_dim):
82
+ """
83
+ Generates a positional encoding signal for the model. The generated
84
+ positional encoding signal is stored as a buffer (`self.fourier_signal`).
85
+
86
+ Args:
87
+ img_size (int): The size of the input image.
88
+ patch_size (int): The size of each patch in the image.
89
+ embed_dim (int): The embedding dimension of the model.
90
+
91
+ Returns:
92
+ None.
93
+ """
94
+ # Generate signal of shape (C, H, W)
95
+ x = torch.linspace(0.0, 1.0, img_size // patch_size)
96
+ y = torch.linspace(0.0, 1.0, img_size // patch_size)
97
+ x, y = torch.meshgrid(x, y, indexing="xy")
98
+ fourier_signal = []
99
+
100
+ frequencies = torch.linspace(1, (img_size // patch_size) / 2.0, embed_dim // 4)
101
+
102
+ for f in frequencies:
103
+ fourier_signal.extend(
104
+ [
105
+ torch.cos(2.0 * torch.pi * f * x),
106
+ torch.sin(2.0 * torch.pi * f * x),
107
+ torch.cos(2.0 * torch.pi * f * y),
108
+ torch.sin(2.0 * torch.pi * f * y),
109
+ ]
110
+ )
111
+ fourier_signal = torch.stack(fourier_signal, dim=2)
112
+ fourier_signal = rearrange(fourier_signal, "h w c -> 1 (h w) c")
113
+ self.register_buffer("pos_embed", fourier_signal)
114
+
115
+ def forward(self, x, dt):
116
+ """
117
+ Args:
118
+ x: Tensor of shape (B, C, T, H, W).
119
+ dt: Tensor of shape (B, T). However it is not used.
120
+ Returns:
121
+ Tensor of shape (B, ST, D)
122
+ """
123
+ x = self.patch_embed(x)
124
+ x = x + self.pos_embed
125
+ x = self.pos_drop(x)
126
+
127
+ return x
128
+
129
+
130
+ class LinearDecoder(nn.Module):
131
+ def __init__(
132
+ self,
133
+ patch_size: int,
134
+ out_chans: int,
135
+ embed_dim: int,
136
+ ):
137
+ """
138
+ Args:
139
+ patch_size: patch size
140
+ out_chans: number of output channels
141
+ embed_dim: embedding dimension
142
+ """
143
+ super().__init__()
144
+
145
+ self.unembed = nn.Sequential(
146
+ nn.Conv2d(
147
+ in_channels=embed_dim,
148
+ out_channels=(patch_size**2) * out_chans,
149
+ kernel_size=1,
150
+ ),
151
+ nn.PixelShuffle(patch_size),
152
+ )
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Args:
157
+ x: Tensor of shape (B, L, D). For ensembles, we have implicitly B = (B E).
158
+ Returns:
159
+ Tensor of shape (B C H W).
160
+ Here
161
+ - C equals num_queries
162
+ - H == W == sqrt(L) x patch_size
163
+ """
164
+ # Reshape the tokens to 2d token space: (B, C, H_token, W_token)
165
+ _, L, _ = x.shape
166
+ H_token = W_token = int(L**0.5)
167
+ x = rearrange(x, "B (H W) D -> B D H W", H=H_token, W=W_token)
168
+
169
+ # Unembed the tokens. Convolution + pixel shuffle.
170
+ x = self.unembed(x)
171
+
172
+ return x
173
+
174
+
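+ # Minimal sketch of the decoder's shape contract with illustrative sizes: L = 16
+ # tokens form a 4x4 token grid, and the 1x1 convolution plus PixelShuffle upsample
+ # each token by `patch_size`, giving (B, out_chans, 4 * patch_size, 4 * patch_size).
+ def _demo_linear_decoder_shapes() -> None:
+     decoder = LinearDecoder(patch_size=2, out_chans=13, embed_dim=32)
+     tokens = torch.zeros(1, 16, 32)                 # (B, L, D)
+     assert decoder(tokens).shape == (1, 13, 8, 8)   # (B, C, H, W)
+
+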
175
+ class MLP(nn.Module):
176
+ """A simple one-hidden-layer MLP."""
177
+
178
+ def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
179
+ """Initialise.
180
+
181
+ Args:
182
+ dim (int): Input dimensionality.
183
+ hidden_features (int): Width of the hidden layer.
184
+ dropout (float, optional): Drop-out rate. Defaults to no drop-out.
185
+ """
186
+ super().__init__()
187
+ self.net = nn.Sequential(
188
+ nn.Linear(dim, hidden_features),
189
+ nn.GELU(),
190
+ nn.Linear(hidden_features, dim),
191
+ nn.Dropout(dropout),
192
+ )
193
+
194
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
195
+ """Run the MLP."""
196
+ return self.net(x)
197
+
198
+
199
+ class PerceiverAttention(nn.Module):
200
+ """Cross attention module from the Perceiver architecture."""
201
+
202
+ def __init__(
203
+ self,
204
+ latent_dim: int,
205
+ context_dim: int,
206
+ head_dim: int = 64,
207
+ num_heads: int = 8,
208
+ ) -> None:
209
+ """Initialise.
210
+
211
+ Args:
212
+ latent_dim (int): Dimensionality of the latent features given as input.
213
+ context_dim (int): Dimensionality of the context features also given as input.
214
+ head_dim (int): Attention head dimensionality.
215
+ num_heads (int): Number of heads.
216
+ """
217
+ super().__init__()
218
+ self.num_heads = num_heads
219
+ self.head_dim = head_dim
220
+ self.inner_dim = head_dim * num_heads
221
+
222
+ self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
223
+ self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
224
+ self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)
225
+
226
+ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
227
+ """Run the cross-attention module.
228
+
229
+ Args:
230
+ latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, Latent_D)`
231
+ where typically `L1 < L2` and `Latent_D <= Context_D`. `Latent_D` is equal to
232
+ `self.latent_dim`.
233
+ x (:class:`torch.Tensor`): Context features of shape `(B, L2, Context_D)`.
234
+
235
+ Returns:
236
+ :class:`torch.Tensor`: Latent values of shape `(B, L1, Latent_D)`.
237
+ """
238
+ h = self.num_heads
239
+
240
+ q = self.to_q(latents) # (B, L1, D2) to (B, L1, D)
241
+ k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) to twice (B, L2, D)
242
+ q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
243
+
244
+ out = F.scaled_dot_product_attention(q, k, v)
245
+ out = rearrange(out, "B H L1 D -> B L1 (H D)") # (B, L1, D)
246
+ return self.to_out(out) # (B, L1, Latent_D)
247
+
248
+
249
+ class PerceiverResampler(nn.Module):
250
+ """Perceiver Resampler module from the Flamingo paper."""
251
+
252
+ def __init__(
253
+ self,
254
+ latent_dim: int,
255
+ context_dim: int,
256
+ depth: int = 1,
257
+ head_dim: int = 64,
258
+ num_heads: int = 16,
259
+ mlp_ratio: float = 4.0,
260
+ drop: float = 0.0,
261
+ residual_latent: bool = True,
262
+ ln_eps: float = 1e-5,
263
+ ) -> None:
264
+ """Initialise.
265
+
266
+ Args:
267
+ latent_dim (int): Dimensionality of the latent features given as input.
268
+ context_dim (int): Dimensionality of the context features also given as input.
269
+ depth (int, optional): Number of attention layers.
270
+ head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
271
+ num_heads (int, optional): Number of heads. Defaults to `16`
272
+ mlp_ratio (float, optional): Dimensionality of the hidden layer divided by that of the
273
+ input for all MLPs. Defaults to `4.0`.
274
+ drop (float, optional): Drop-out rate. Defaults to no drop-out.
275
+ residual_latent (bool, optional): Use residual attention w.r.t. the latent features.
276
+ Defaults to `True`.
277
+ ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
278
+ `1e-5`.
279
+ """
280
+ super().__init__()
281
+
282
+ self.residual_latent = residual_latent
283
+ self.layers = nn.ModuleList([])
284
+ mlp_hidden_dim = int(latent_dim * mlp_ratio)
285
+ for _ in range(depth):
286
+ self.layers.append(
287
+ nn.ModuleList(
288
+ [
289
+ PerceiverAttention(
290
+ latent_dim=latent_dim,
291
+ context_dim=context_dim,
292
+ head_dim=head_dim,
293
+ num_heads=num_heads,
294
+ ),
295
+ MLP(
296
+ dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop
297
+ ),
298
+ nn.LayerNorm(latent_dim, eps=ln_eps),
299
+ nn.LayerNorm(latent_dim, eps=ln_eps),
300
+ ]
301
+ )
302
+ )
303
+
304
+ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
305
+ """Run the module.
306
+
307
+ Args:
308
+ latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, D1)`.
309
+ x (:class:`torch.Tensor`): Context features of shape `(B, L2, D1)`.
310
+
311
+ Returns:
312
+ torch.Tensor: Latent features of shape `(B, L1, D1)`.
313
+ """
314
+ for attn, ff, ln1, ln2 in self.layers:
315
+ # We use post-res-norm like in Swin v2 and most Transformer architectures these days.
316
+ # This empirically works better than the pre-norm used in the original Perceiver.
317
+ attn_out = ln1(attn(latents, x))
318
+ # HuggingFace suggests using non-residual attention in Perceiver might work better when
319
+ # the semantics of the query and the output are different:
320
+ #
321
+ # https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
322
+ #
323
+ latents = attn_out + latents if self.residual_latent else attn_out
324
+ latents = ln2(ff(latents)) + latents
325
+ return latents
326
+
327
+
328
+ class PerceiverChannelEmbedding(nn.Module):
329
+ def __init__(
330
+ self,
331
+ in_chans: int,
332
+ img_size: int,
333
+ patch_size: int,
334
+ time_dim: int,
335
+ num_queries: int,
336
+ embed_dim: int,
337
+ drop_rate: float,
338
+ ):
339
+ super().__init__()
340
+
341
+ if embed_dim % 2 != 0:
342
+ raise ValueError(
343
+ f"Temporal embeddings require `embed_dim` to be even. Currently we have {embed_dim}."
344
+ )
345
+
346
+ self.num_patches = (img_size // patch_size) ** 2
347
+ self.num_queries = num_queries
348
+ self.embed_dim = embed_dim
349
+
350
+ self.proj = nn.Conv2d(
351
+ in_channels=in_chans * time_dim,
352
+ out_channels=in_chans * embed_dim,
353
+ kernel_size=patch_size,
354
+ stride=patch_size,
355
+ groups=in_chans,
356
+ )
357
+
358
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_patches))
359
+ trunc_normal_(self.pos_embed, std=0.02)
360
+
361
+ self.latent_queries = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
362
+ trunc_normal_(self.latent_queries, std=0.02)
363
+
364
+ self.perceiver = PerceiverResampler(
365
+ latent_dim=embed_dim,
366
+ context_dim=embed_dim,
367
+ depth=1,
368
+ head_dim=embed_dim // 16,
369
+ num_heads=16,
370
+ mlp_ratio=4.0,
371
+ drop=0.0,
372
+ residual_latent=False,
373
+ ln_eps=1e-5,
374
+ )
375
+
376
+ self.latent_aggregation = nn.Linear(num_queries * embed_dim, embed_dim)
377
+
378
+ self.pos_drop = nn.Dropout(p=drop_rate)
379
+
380
+ def forward(self, x, dt):
381
+ """
382
+ Args:
383
+ x: Tensor of shape (B, C, T, H, W)
384
+ dt: Tensor of shape (B, T) identifying time deltas.
385
+ Returns:
386
+ Tensor of shape (B, ST, D)
387
+ """
388
+ B, C, T, H, W = x.shape
389
+ x = rearrange(x, "B C T H W -> B (C T) H W")
390
+ x = self.proj(x) # B (C T) H W -> B (C D) HT WT
391
+ x = x.flatten(2, 3) # B (C D) ST
392
+ ST = x.shape[2]
393
+ assert ST == self.num_patches
394
+ x = rearrange(x, "B (C D) ST -> (B C) D ST", B=B, ST=ST, C=C, D=self.embed_dim)
395
+ x = x + self.pos_embed
396
+ x = rearrange(x, "(B C) D ST -> (B ST) C D", B=B, ST=ST, C=C, D=self.embed_dim)
397
+
398
+ # ((B ST) NQ D), ((B ST) C D) -> ((B ST) NQ D)
399
+ x = self.perceiver(self.latent_queries.expand(B * ST, -1, -1), x)
400
+ x = rearrange(
401
+ x,
402
+ "(B ST) NQ D -> B ST (NQ D)",
403
+ B=B,
404
+ ST=self.num_patches,
405
+ NQ=self.num_queries,
406
+ D=self.embed_dim,
407
+ )
408
+ x = self.latent_aggregation(x) # B ST (NQ D) -> B ST D'
409
+
410
+ assert x.shape[1] == self.num_patches
411
+ assert x.shape[2] == self.embed_dim
412
+
413
+ x = self.pos_drop(x)
414
+
415
+ return x
416
+
417
+
418
+ class PerceiverDecoder(nn.Module):
419
+ def __init__(
420
+ self,
421
+ embed_dim: int,
422
+ patch_size: int,
423
+ out_chans: int,
424
+ ):
425
+ """
426
+ Args:
427
+ embed_dim: embedding dimension
428
+ patch_size: patch size
429
+ out_chans: number of output channels. This determines the number of latent queries.
431
+ """
432
+ super().__init__()
433
+
434
+ self.embed_dim = embed_dim
435
+ self.patch_size = patch_size
436
+ self.out_chans = out_chans
437
+
438
+ self.latent_queries = nn.Parameter(torch.zeros(1, out_chans, embed_dim))
439
+ trunc_normal_(self.latent_queries, std=0.02)
440
+
441
+ self.perceiver = PerceiverResampler(
442
+ latent_dim=embed_dim,
443
+ context_dim=embed_dim,
444
+ depth=1,
445
+ head_dim=embed_dim // 16,
446
+ num_heads=16,
447
+ mlp_ratio=4.0,
448
+ drop=0.0,
449
+ residual_latent=False,
450
+ ln_eps=1e-5,
451
+ )
452
+ self.proj = nn.Conv2d(
453
+ in_channels=out_chans * embed_dim,
454
+ out_channels=out_chans * patch_size**2,
455
+ kernel_size=1,
456
+ padding=0,
457
+ groups=out_chans,
458
+ )
459
+ self.pixel_shuffle = nn.PixelShuffle(patch_size)
460
+
461
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
462
+ """
463
+ Args:
464
+ x: Tensor of shape (B, L, D) For ensembles, we have implicitly B = (B E).
465
+ Returns:
466
+ Tensor of shape (B C H W).
467
+ Here
468
+ - C equals out_chans
469
+ - H == W == sqrt(L) x patch_size
470
+ """
471
+ B, L, D = x.shape
472
+ H_token = W_token = int(L**0.5)
473
+
474
+ x = rearrange(x, "B L D -> (B L) 1 D")
475
+ # (B L) 1 D -> (B L) C D
476
+ x = self.perceiver(self.latent_queries.expand(B * L, -1, -1), x)
477
+ x = rearrange(x, "(B H W) C D -> B (C D) H W", H=H_token, W=W_token)
478
+ # B (C D) H_token W_token -> B (C patch_size patch_size) H_token W_token
479
+ x = self.proj(x)
480
+ # B (C patch_size patch_size) H_token W_token -> B C H W
481
+ x = self.pixel_shuffle(x)
482
+
483
+ return x
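
Taken together, PerceiverChannelEmbedding compresses each channel's patch tokens into a small set of latent queries, and PerceiverDecoder reverses the process per token before a grouped convolution and pixel shuffle restore the image. The following is a minimal shape-check sketch, assuming the classes above are importable from surya.models.embedding (as they are imported elsewhere in this commit); the sizes are deliberately small and are not the released model's configuration.

import torch
from surya.models.embedding import PerceiverChannelEmbedding, PerceiverDecoder

# Small, illustrative sizes; the real model uses much larger values.
B, C, T, H, W = 2, 13, 2, 64, 64
patch, dim = 16, 128

embed = PerceiverChannelEmbedding(
    in_chans=C, img_size=H, patch_size=patch, time_dim=T,
    num_queries=4, embed_dim=dim, drop_rate=0.0,
)
decode = PerceiverDecoder(embed_dim=dim, patch_size=patch, out_chans=C)

x = torch.randn(B, C, T, H, W)
dt = torch.zeros(B, T)   # time deltas; not used by this embedding's forward pass
tokens = embed(x, dt)    # (B, (H // patch) * (W // patch), dim)
y = decode(tokens)       # (B, C, H, W)
print(tokens.shape, y.shape)
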
surya/models/flow.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class HelioFlowModel(nn.Module):
7
+ def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False):
8
+ super().__init__()
9
+
10
+ self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
11
+
12
+ u = torch.linspace(-1, 1, img_size[0])
13
+ v = torch.linspace(-1, 1, img_size[1])
14
+ u, v = torch.meshgrid(u, v, indexing="xy")
15
+ self.register_buffer(
16
+ "grid", torch.stack((u, v), dim=2).view(1, *img_size, 2)
17
+ ) # B, H, W, 2
18
+
19
+ # Higher modes can be used for explicit feature engineering for flow features.
20
+ if self.use_latitude_in_learned_flow:
21
+ higher_modes = [u, v, torch.ones_like(u)]
22
+ else:
23
+ higher_modes = [
24
+ u,
25
+ v,
26
+ ]
27
+ self.register_buffer(
28
+ "higher_modes", torch.stack(higher_modes, dim=2).view(1, *img_size, -1)
29
+ )
30
+
31
+ self.flow_generator = nn.Sequential(
32
+ nn.Linear(self.higher_modes.shape[3], 128),
33
+ nn.GELU(),
34
+ nn.Linear(128, 2),
35
+ )
36
+
37
+ def forward(self, batch):
38
+ """
39
+ Args:
40
+ batch: Dictionary containing keys `ts` and
41
+ `forecast_latitude` (optionally).
42
+ ts (torch.Tensor): B, C, T, H, W
43
+ forecast_latitude (torch.Tensor): B, L
44
+ B - Batch size, C - Channels, T - Input times, H - Image height,
45
+ W - Image width, L - Lead time.
46
+ """
47
+
48
+ x = batch["ts"]
49
+ B, C, T, H, W = x.shape
50
+ if T == 1:
51
+ x = x[:, :, -1, :, :]
52
+ else:
53
+ # Taking the average of the last two time stamps
54
+ x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2
55
+
56
+ # Flow fields have the shape B, H_out, W_out, 2
57
+ if self.use_latitude_in_learned_flow:
58
+ broadcast_lat = batch["forecast_latitude"] / 7
59
+ broadcast_lat = torch.concatenate(
60
+ [
61
+ torch.ones_like(broadcast_lat),
62
+ torch.ones_like(broadcast_lat),
63
+ broadcast_lat,
64
+ ],
65
+ 1,
66
+ )[:, None, None, :]
67
+ higher_modes = self.higher_modes * broadcast_lat
68
+ flow_field = self.grid + self.flow_generator(higher_modes)
69
+ else:
70
+ flow_field = self.grid + self.flow_generator(self.higher_modes)
71
+ flow_field = flow_field.expand(B, H, W, 2)
72
+
73
+ y_hat = F.grid_sample(
74
+ x,
75
+ flow_field,
76
+ mode="bilinear",
77
+ padding_mode="border", # Possible values: zeros, border, or reflection.
78
+ align_corners=False,
79
+ )
80
+
81
+ return y_hat
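
HelioFlowModel advects the most recent input frame with a learned flow field through F.grid_sample. Below is a standalone sketch of just that warping step with an identity grid, using the same normalized-coordinate and padding conventions as the model; sizes are illustrative only.

import torch
import torch.nn.functional as F

B, C, H, W = 1, 3, 32, 32
x = torch.randn(B, C, H, W)

# Identity sampling grid in normalized [-1, 1] coordinates, as in HelioFlowModel.
u = torch.linspace(-1, 1, H)
v = torch.linspace(-1, 1, W)
u, v = torch.meshgrid(u, v, indexing="xy")
grid = torch.stack((u, v), dim=2).unsqueeze(0)  # (1, H, W, 2)

y = F.grid_sample(
    x, grid.expand(B, H, W, 2),
    mode="bilinear", padding_mode="border", align_corners=False,
)
# With align_corners=False the grid points do not land exactly on pixel centres,
# so this "identity" warp reproduces the input only approximately.
print((y - x).abs().max())
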
surya/models/helio_spectformer.py ADDED
@@ -0,0 +1,318 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+ import numpy as np
6
+
7
+ from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
8
+ from .embedding import (
9
+ LinearEmbedding,
10
+ PatchEmbed3D,
11
+ PerceiverChannelEmbedding,
12
+ LinearDecoder,
13
+ PerceiverDecoder,
14
+ )
15
+ from .flow import HelioFlowModel
16
+
17
+
18
+ class HelioSpectFormer(nn.Module):
19
+ """
20
+ A note on the ensemble capability:
21
+ Ensembles of size E are generated by setting `ensemble=E`. In this case, the forward
22
+ pass generates ensemble members after tokenization by increasing the batch dimension
23
+ B to B x E. Noise is injected in the `self.backbone` SpectFormer blocks. After the
24
+ backbone, ensemble members ride along implicitly in the batch dimension. (This is
25
+ mainly through the `self.unembed` pass.) An explicit ensemble dimension is only
26
+ generated at the end.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ img_size: int,
32
+ patch_size: int,
33
+ in_chans: int,
34
+ embed_dim: int,
35
+ time_embedding: dict,
36
+ depth: int,
37
+ n_spectral_blocks: int,
38
+ num_heads: int,
39
+ mlp_ratio: float,
40
+ drop_rate: float,
41
+ window_size: int,
42
+ dp_rank: int,
43
+ learned_flow: bool = False,
44
+ use_latitude_in_learned_flow: bool = False,
45
+ init_weights: bool = False,
46
+ checkpoint_layers: list[int] | None = None,
47
+ rpe: bool = False,
48
+ ensemble: int | None = None,
49
+ finetune: bool = True,
50
+ nglo: int = 0,
51
+ dtype: torch.dtype | None = None,
52
+ ) -> None:
53
+ """
54
+ Args:
55
+ img_size: input image size
56
+ patch_size: patch size
57
+ in_chans: number of input channels
58
+ embed_dim: embedding dimension
59
+ time_embedding: dictionary to configure temporal embedding:
60
+ `type` (str, required): indicates embedding type. `linear`, `perceiver`.
61
+ `time_dim` (int): indicates length of time dimension. required for linear embedding.
62
+ `n_queries` (int): indicates number of perceiver queries. required for perceiver.
63
+ depth: number of transformer blocks
64
+ n_spectral_blocks: number of spectral gating blocks
65
+ num_heads: Number of transformer heads
66
+ mlp_ratio: MLP ratio for transformer blocks
67
+ drop_rate: dropout rate
68
+ window_size: window size for long/short attention
69
+ dp_rank: dp rank for long/short attention
70
+ learned_flow: if true, combine learned flow model with spectformer
71
+ use_latitude_in_learned_flow: use latitudes in learned flow
72
+ init_weights: use optimized weight initialization
73
+ checkpoint_layers: indicate which layers to use for checkpointing
74
+ rpe: Use relative position encoding in Long-Short attention blocks.
75
+ ensemble: Integer indicating ensemble size or None for deterministic model.
76
+ finetune: Indicates whether to train from scratch or fine-tune the model. If set to `True`, the final output layers are removed.
77
+ nglo: Number of (additional) global tokens.
78
+ dtype: A torch data type. Not used and added only for compatibility with the remainder of the codebase.
79
+ """
80
+ super().__init__()
81
+
82
+ self.learned_flow = learned_flow
83
+ self.patch_size = patch_size
84
+ self.embed_dim = embed_dim
85
+ self.in_chans = in_chans
86
+ self.time_embedding = time_embedding
87
+ self.ensemble = ensemble
88
+ self.finetune = finetune
89
+ self.nglo = nglo
90
+
91
+ if learned_flow:
92
+ self.learned_flow_model = HelioFlowModel(
93
+ img_size=(img_size, img_size),
94
+ use_latitude_in_learned_flow=use_latitude_in_learned_flow,
95
+ )
96
+
97
+ match time_embedding["type"]:
98
+ case "linear":
99
+ self.time_dim = time_embedding["time_dim"]
100
+ if learned_flow:
101
+ self.time_dim += 1
102
+ self.embedding = LinearEmbedding(
103
+ img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate
104
+ )
105
+
106
+ if not self.finetune:
107
+ self.unembed = LinearDecoder(
108
+ patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim
109
+ )
110
+ case "perceiver":
111
+ self.embedding = PerceiverChannelEmbedding(
112
+ in_chans=in_chans,
113
+ img_size=img_size,
114
+ patch_size=patch_size,
115
+ time_dim=time_embedding["time_dim"],
116
+ num_queries=time_embedding["n_queries"],
117
+ embed_dim=embed_dim,
118
+ drop_rate=drop_rate,
119
+ )
120
+ if not self.finetune:
121
+ self.unembed = PerceiverDecoder(
122
+ embed_dim=embed_dim,
123
+ patch_size=patch_size,
124
+ out_chans=in_chans,
125
+ )
126
+ case _:
127
+ raise NotImplementedError(
128
+ f'Embedding {time_embedding["type"]} has not been implemented.'
129
+ )
130
+
131
+ if isinstance(depth, list):
132
+ raise NotImplementedError(
133
+ "Multi scale models are no longer supported. Depth should be a single integer."
134
+ )
135
+ self.backbone = SpectFormer(
136
+ grid_size=img_size // patch_size,
137
+ embed_dim=embed_dim,
138
+ depth=depth,
139
+ n_spectral_blocks=n_spectral_blocks,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ drop_rate=drop_rate,
143
+ window_size=window_size,
144
+ dp_rank=dp_rank,
145
+ checkpoint_layers=checkpoint_layers,
146
+ rpe=rpe,
147
+ ensemble=ensemble,
148
+ nglo=nglo,
149
+ )
150
+
151
+ if init_weights:
152
+ self.apply(self._init_weights)
153
+
154
+ # @staticmethod
155
+ # def _checkpoint_wrapper(
156
+ # model: nn.Module, data: tuple[Tensor, Tensor | None]
157
+ # ) -> Tensor:
158
+ # return checkpoint(model, data, use_reentrant=False)
159
+
160
+ def _init_weights(self, module):
161
+
162
+ if self.time_embedding["type"] == "linear":
163
+ # sampling_step * embed_dim = patch_size**2 * in_chans * time_dim
164
+ sampling_step = int(
165
+ np.sqrt(
166
+ (self.patch_size**2 * self.in_chans * self.time_dim)
167
+ / self.embed_dim
168
+ )
169
+ )
170
+ else:
171
+ sampling_step = int(
172
+ np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim)
173
+ )
174
+ if isinstance(module, PatchEmbed3D):
175
+ torch.nn.init.zeros_(module.proj.weight)
176
+ c_out = 0
177
+ w_pool = 1.0 / sampling_step
178
+ for k in range(self.in_chans * self.time_dim):
179
+ for i in range(0, self.patch_size, sampling_step):
180
+ for j in range(0, self.patch_size, sampling_step):
181
+ module.proj.weight.data[
182
+ c_out, k, i : i + sampling_step, j : j + sampling_step
183
+ ] = w_pool
184
+ c_out += 1
185
+ if module.proj.bias is not None:
186
+ module.proj.bias.data.zero_()
187
+ if isinstance(module, BlockSpectralGating):
188
+ for m in [
189
+ module.mlp.fc1,
190
+ module.mlp.fc2,
191
+ ]:
192
+ # m.weight.data.normal_(mean=0.0, std=0.01)
193
+ # torch.nn.init.eye_(m.weight)
194
+ torch.nn.init.eye_(m.weight)
195
+ if m.bias is not None:
196
+ m.bias.data.zero_()
197
+ if isinstance(module, BlockAttention):
198
+ for m in [
199
+ module.mlp.fc1,
200
+ module.mlp.fc2,
201
+ ]:
202
+ # torch.nn.init.eye_(m.weight)
203
+ torch.nn.init.zeros_(m.weight)
204
+ if m.bias is not None:
205
+ m.bias.data.zero_()
206
+ for m in [
207
+ module.attn.qkv,
208
+ module.attn.proj,
209
+ module.attn.to_dynamic_projection,
210
+ ]:
211
+ # m.weight.data.normal_(mean=0.0, std=0.01)
212
+ # torch.nn.init.eye_(m.weight)
213
+ torch.nn.init.zeros_(m.weight)
214
+ if m.bias is not None:
215
+ m.bias.data.zero_()
216
+ if isinstance(module, torch.nn.Sequential):
217
+ if isinstance(module[1], torch.nn.PixelShuffle):
218
+ # torch.nn.init.eye_(module[0].weight.data[:,:,0,0])
219
+ torch.nn.init.zeros_(module[0].weight)
220
+ if self.time_embedding["type"] == "linear":
221
+ c_out = 0
222
+ for k in range(1, self.in_chans + 1):
223
+ for i in range(
224
+ self.patch_size**2 // (self.patch_size * sampling_step)
225
+ ):
226
+ for j in range(self.patch_size):
227
+ module[0].weight.data[
228
+ c_out : c_out + sampling_step,
229
+ j + (k * self.time_dim - 1) * self.patch_size,
230
+ ] = 1.0
231
+ c_out += sampling_step
232
+ else:
233
+ c_out = 0
234
+ for k in range(2 * self.in_chans):
235
+ # l = 0
236
+ for l_feat in range(self.backbone.embed_dim):
237
+ module[0].weight.data[c_out, l_feat] = 1.0
238
+ c_out += 1
239
+ if module[0].bias is not None:
240
+ module[0].bias.data.zero_()
241
+
242
+ def forward(self, batch):
243
+ """
244
+ Args:
245
+ batch: Dictionary containing keys `ts` and `time_delta_input`.
246
+ Their values are tensors with shapes as follows.
247
+ ts: B, C, T, H, W
248
+ time_delta_input: B, T
249
+ Returns:
250
+ Tensor of shape (B, C, H, W) for deterministic or (B, E, C, H, W) for ensemble forecasts.
251
+ """
252
+ x = batch["ts"]
253
+ dt = batch["time_delta_input"]
254
+ B, C, T, H, W = x.shape
255
+
256
+ if self.learned_flow:
257
+ y_hat_flow = self.learned_flow_model(batch) # B, C, H, W
258
+ if any(
259
+ [param.requires_grad for param in self.learned_flow_model.parameters()]
260
+ ):
261
+ return y_hat_flow
262
+ else:
263
+ x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2) # B, C, T+1, H, W
264
+ if self.time_embedding["type"] == "perceiver":
265
+ dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1)
266
+
267
+ # embed the data
268
+ tokens = self.embedding(x, dt)
269
+
270
+ # copy tokens in case of ensemble forecast
271
+ if self.ensemble:
272
+ # B L D -> (B E) L D == BE L D
273
+ tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0)
274
+
275
+ # pass the time series through the encoder
276
+ tokens = self.backbone(tokens)
277
+
278
+ if self.finetune:
279
+ return tokens
280
+
281
+ # Unembed the tokens
282
+ # BE L D -> BE C H W
283
+ forecast_hat = self.unembed(tokens)
284
+
285
+ assert forecast_hat.shape == (
286
+ B * self.ensemble if self.ensemble else B,
287
+ C,
288
+ H,
289
+ W,
290
+ ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
291
+
292
+ if self.learned_flow:
293
+ assert y_hat_flow.shape == (
294
+ B,
295
+ C,
296
+ H,
297
+ W,
298
+ ), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}."
299
+ if self.ensemble:
300
+ y_hat_flow = torch.repeat_interleave(
301
+ y_hat_flow, repeats=self.ensemble, dim=0
302
+ )
303
+ assert y_hat_flow.shape == forecast_hat.shape
304
+ forecast_hat = forecast_hat + y_hat_flow
305
+
306
+ assert forecast_hat.shape == (
307
+ B * self.ensemble if self.ensemble else B,
308
+ C,
309
+ H,
310
+ W,
311
+ ), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
312
+
313
+ if self.ensemble:
314
+ forecast_hat = rearrange(
315
+ forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble
316
+ )
317
+
318
+ return forecast_hat
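
A smoke-test sketch of building and calling the model follows. The hyperparameters are placeholders chosen only so that the shapes work out; the released checkpoint's configuration comes from data/Surya-1.0/config.yaml and differs from these values. The sketch also assumes the sibling embedding module provides LinearEmbedding and LinearDecoder as imported above.

import torch
from surya.models.helio_spectformer import HelioSpectFormer

# Toy configuration for a quick forward pass; not the published Surya-1.0 settings.
model = HelioSpectFormer(
    img_size=128, patch_size=16, in_chans=13, embed_dim=256,
    time_embedding={"type": "linear", "time_dim": 2},
    depth=4, n_spectral_blocks=2, num_heads=8, mlp_ratio=4.0,
    drop_rate=0.0, window_size=2, dp_rank=2, finetune=False,
)
batch = {
    "ts": torch.randn(1, 13, 2, 128, 128),
    "time_delta_input": torch.zeros(1, 2),
}
with torch.no_grad():
    y_hat = model(batch)   # (B, C, H, W) for the deterministic model
print(y_hat.shape)
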
surya/models/spectformer.py ADDED
@@ -0,0 +1,305 @@
1
+ import math
2
+ import logging
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from timm.models.layers import DropPath, trunc_normal_
10
+ import torch.fft
11
+
12
+ from .transformer_ls import AttentionLS
13
+
14
+ _logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features,
21
+ hidden_features=None,
22
+ out_features=None,
23
+ act_layer=nn.GELU,
24
+ drop=0.0,
25
+ ):
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x):
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
41
+
42
+
43
+ class SpectralGatingNetwork(nn.Module):
44
+ def __init__(self, dim, h=14, w=8):
45
+ super().__init__()
46
+ self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
47
+ self.w = w
48
+ self.h = h
49
+
50
+ def forward(self, x, spatial_size=None):
51
+ B, N, C = x.shape # torch.Size([1, 262144, 1024])
52
+ if spatial_size is None:
53
+ a = b = int(math.sqrt(N)) # a=b=512
54
+ else:
55
+ a, b = spatial_size
56
+
57
+ x = x.view(B, a, b, C) # torch.Size([1, 512, 512, 1024])
58
+
59
+ # FROM HERE USED TO BE AUTOCAST to float32
60
+ dtype = x.dtype
61
+ x = x.to(torch.float32)
62
+ x = torch.fft.rfft2(
63
+ x, dim=(1, 2), norm="ortho"
64
+ ) # torch.Size([1, 512, 257, 1024])
65
+ weight = torch.view_as_complex(
66
+ self.complex_weight.to(torch.float32)
67
+ ) # torch.Size([512, 257, 1024])
68
+ x = x * weight
69
+ x = torch.fft.irfft2(
70
+ x, s=(a, b), dim=(1, 2), norm="ortho"
71
+ ) # torch.Size([1, 512, 512, 1024])
72
+ x = x.to(dtype)
73
+
74
+ x = x.reshape(B, N, C) # torch.Size([1, 262144, 1024])
75
+ # UP TO HERE USED TO BE AUTOCAST to float32
76
+
77
+ return x
78
+
79
+
80
+ class BlockSpectralGating(nn.Module):
81
+ def __init__(
82
+ self,
83
+ dim,
84
+ mlp_ratio=4.0,
85
+ drop=0.0,
86
+ drop_path=0.0,
87
+ act_layer=nn.GELU,
88
+ norm_layer=nn.LayerNorm,
89
+ h=14,
90
+ w=8,
91
+ ):
92
+ super().__init__()
93
+ self.norm1 = norm_layer(dim)
94
+ self.filter = SpectralGatingNetwork(dim, h=h, w=w)
95
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
96
+ self.norm2 = norm_layer(dim)
97
+ mlp_hidden_dim = int(dim * mlp_ratio)
98
+ self.mlp = Mlp(
99
+ in_features=dim,
100
+ hidden_features=mlp_hidden_dim,
101
+ act_layer=act_layer,
102
+ drop=drop,
103
+ )
104
+
105
+ def forward(self, x, *args):
106
+ x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
107
+ return x
108
+
109
+
110
+ class BlockAttention(nn.Module):
111
+ def __init__(
112
+ self,
113
+ dim,
114
+ num_heads: int = 8,
115
+ mlp_ratio=4.0,
116
+ drop=0.0,
117
+ drop_path=0.0,
118
+ w=2,
119
+ dp_rank=2,
120
+ act_layer=nn.GELU,
121
+ norm_layer=nn.LayerNorm,
122
+ rpe=False,
123
+ adaLN=False,
124
+ nglo=0,
125
+ ):
126
+ """
127
+ num_heads: Attention heads. 4 for tiny, 8 for small and 12 for base
128
+ """
129
+ super().__init__()
130
+ self.norm1 = norm_layer(dim)
131
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
132
+ self.norm2 = norm_layer(dim)
133
+ mlp_hidden_dim = int(dim * mlp_ratio)
134
+ self.mlp = Mlp(
135
+ in_features=dim,
136
+ hidden_features=mlp_hidden_dim,
137
+ act_layer=act_layer,
138
+ drop=drop,
139
+ )
140
+ self.attn = AttentionLS(
141
+ dim=dim,
142
+ num_heads=num_heads,
143
+ w=w,
144
+ dp_rank=dp_rank,
145
+ nglo=nglo,
146
+ rpe=rpe,
147
+ )
148
+
149
+ if adaLN:
150
+ self.adaLN_modulation = nn.Sequential(
151
+ nn.Linear(dim, dim, bias=True),
152
+ act_layer(),
153
+ nn.Linear(dim, 6 * dim, bias=True),
154
+ )
155
+ else:
156
+ self.adaLN_modulation = None
157
+
158
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
159
+ if self.adaLN_modulation is not None:
160
+ (
161
+ shift_mha,
162
+ scale_mha,
163
+ gate_mha,
164
+ shift_mlp,
165
+ scale_mlp,
166
+ gate_mlp,
167
+ ) = self.adaLN_modulation(c).chunk(6, dim=2)
168
+ else:
169
+ shift_mha, scale_mha, gate_mha, shift_mlp, scale_mlp, gate_mlp = 6 * (1.0,)
170
+
171
+ x = x + gate_mha * self.drop_path(
172
+ self.attn(
173
+ self.norm1(x) * scale_mha + shift_mha,
174
+ )
175
+ )
176
+ x = x + gate_mlp * self.drop_path(
177
+ self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
178
+ )
179
+
180
+ return x
181
+
182
+
183
+ class SpectFormer(nn.Module):
184
+ def __init__(
185
+ self,
186
+ grid_size: int = 224 // 16,
187
+ embed_dim=768,
188
+ depth=12,
189
+ n_spectral_blocks=4,
190
+ num_heads: int = 8,
191
+ mlp_ratio=4.0,
192
+ uniform_drop=False,
193
+ drop_rate=0.0,
194
+ drop_path_rate=0.0,
195
+ window_size=2,
196
+ dp_rank=2,
197
+ norm_layer=nn.LayerNorm,
198
+ checkpoint_layers: list[int] | None = None,
199
+ rpe=False,
200
+ ensemble: int | None = None,
201
+ nglo: int = 0,
202
+ ):
203
+ """
204
+ Args:
205
+ grid_size (int): size of the token grid, i.e. input image size divided by patch size
206
+ num_heads (int): number of attention heads
207
+ embed_dim (int): embedding dimension
208
+ depth (int): depth of transformer
209
+ n_spectral_blocks (int): number of spectral gating blocks
210
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
211
+ uniform_drop (bool): true for uniform, false for linearly increasing drop path probability.
212
+ drop_rate (float): dropout rate
213
+ drop_path_rate (float): drop path (stochastic depth) rate
214
+ window_size: window size for long/short attention
215
+ dp_rank: dp rank for long/short attention
216
+ norm_layer: (nn.Module): normalization layer for attention blocks
217
+ checkpoint_layers: indicate which layers to use for checkpointing
218
+ rpe: Use relative position encoding in Long-Short attention blocks.
219
+ ensemble: Integer indicating ensemble size or None for deterministic model.
220
+ nglo: Number of (additional) global tokens.
221
+ """
222
+ super().__init__()
223
+ self.embed_dim = embed_dim
224
+ self.n_spectral_blocks = n_spectral_blocks
225
+ self._checkpoint_layers = checkpoint_layers or []
226
+ self.ensemble = ensemble
227
+ self.nglo = nglo
228
+
229
+ h = grid_size
230
+ w = h // 2 + 1
231
+
232
+ if uniform_drop:
233
+ _logger.info(f"Using uniform droppath with expect rate {drop_path_rate}.")
234
+ dpr = [drop_path_rate for _ in range(depth)]
235
+ else:
236
+ _logger.info(
237
+ f"Using linear droppath with expect rate {drop_path_rate * 0.5}."
238
+ )
239
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
240
+
241
+ self.blocks_spectral_gating = nn.ModuleList()
242
+ self.blocks_attention = nn.ModuleList()
243
+ for i in range(depth):
244
+ if i < n_spectral_blocks:
245
+ layer = BlockSpectralGating(
246
+ dim=embed_dim,
247
+ mlp_ratio=mlp_ratio,
248
+ drop=drop_rate,
249
+ drop_path=dpr[i],
250
+ norm_layer=norm_layer,
251
+ h=h,
252
+ w=w,
253
+ )
254
+ self.blocks_spectral_gating.append(layer)
255
+ else:
256
+ layer = BlockAttention(
257
+ dim=embed_dim,
258
+ num_heads=num_heads,
259
+ mlp_ratio=mlp_ratio,
260
+ drop=drop_rate,
261
+ drop_path=dpr[i],
262
+ norm_layer=norm_layer,
263
+ w=window_size,
264
+ dp_rank=dp_rank,
265
+ rpe=rpe,
266
+ adaLN=True if ensemble is not None else False,
267
+ nglo=nglo,
268
+ )
269
+ self.blocks_attention.append(layer)
270
+
271
+ self.apply(self._init_weights)
272
+
273
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
274
+ """
275
+ Args:
276
+ tokens: Tensor of shape B, N, C for deterministic or BxE, N, C for ensemble forecasts.
277
+ Returns:
278
+ Tensor of same shape as input.
279
+ """
280
+ if self.ensemble:
281
+ BE, N, C = tokens.shape
282
+ noise = torch.randn(
283
+ size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
284
+ )
285
+ else:
286
+ noise = None
287
+
288
+ for i, blk in enumerate(
289
+ chain(self.blocks_spectral_gating, self.blocks_attention)
290
+ ):
291
+ if i in self._checkpoint_layers:
292
+ tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
293
+ else:
294
+ tokens = blk(tokens, noise)
295
+
296
+ return tokens
297
+
298
+ def _init_weights(self, m):
299
+ if isinstance(m, nn.Linear):
300
+ trunc_normal_(m.weight, std=0.02)
301
+ if isinstance(m, nn.Linear) and m.bias is not None:
302
+ nn.init.constant_(m.bias, 0)
303
+ elif isinstance(m, nn.LayerNorm):
304
+ nn.init.constant_(m.bias, 0)
305
+ nn.init.constant_(m.weight, 1.0)
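
The spectral gating blocks filter tokens in the 2-D Fourier domain: tokens are reshaped onto the spatial grid, transformed with rfft2, multiplied by a learned complex weight per frequency and channel, and transformed back. A self-contained sketch of that filtering step with illustrative sizes:

import math
import torch

B, N, C = 2, 64, 32          # 64 tokens form an 8 x 8 grid
h = int(math.sqrt(N))        # 8
w = h // 2 + 1               # 5 one-sided frequencies returned by rfft2
x = torch.randn(B, N, C)

# Learned complex filter, analogous to self.complex_weight above.
weight = torch.view_as_complex(torch.randn(h, w, C, 2) * 0.02)

grid = x.view(B, h, h, C)
spec = torch.fft.rfft2(grid, dim=(1, 2), norm="ortho")   # (B, h, w, C), complex
spec = spec * weight                                      # per-frequency gating
out = torch.fft.irfft2(spec, s=(h, h), dim=(1, 2), norm="ortho").reshape(B, N, C)
print(out.shape)   # torch.Size([2, 64, 32])
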
surya/models/transformer_ls.py ADDED
@@ -0,0 +1,369 @@
1
+ # Copyright (c) 2021 NVIDIA CORPORATION. Licensed under the MIT license.
2
+ # Written by Chen Zhu during an internship at NVIDIA, [email protected]
3
+ import math
4
+
5
+ from torch import nn
6
+ import torch
7
+ from timm.models.layers import trunc_normal_
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class AttentionLS(nn.Module):
12
+ """Implementation for long-short term attention.
13
+ Flexible options for using window attention, global token and dynamic projection.
14
+
15
+ Args:
16
+ dim: input and output feature dimension.
17
+ num_heads: number of attention heads.
18
+ qkv_bias: whether to use bias for the projection of query, key and values.
19
+ qk_scale: scale factor on query and key for numerical stability.
20
+ By default, set to square root of head dimensions.
21
+ attn_drop: dropout probability for attention matrix.
22
+ proj_drop: dropout probability for the final output.
23
+ rpe: whether to use relative position encoding.
24
+ nglo: number of global tokens (e.g., CLS).
25
+
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ dim,
31
+ num_heads=8,
32
+ qkv_bias=False,
33
+ qk_scale=None,
34
+ attn_drop=0.0,
35
+ proj_drop=0.0,
36
+ rpe=False,
37
+ nglo=1,
38
+ dp_rank=2,
39
+ w=2,
40
+ ):
41
+ super().__init__()
42
+ self.num_heads = num_heads
43
+ head_dim = dim // num_heads
44
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
45
+ self.scale = qk_scale or head_dim**-0.5
46
+
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+ self.nglo = nglo
52
+
53
+ # Equals to segment size (w) in the paper.
54
+ self.window_size = w
55
+ # Equals to r in the paper.
56
+ self.dp_rank = dp_rank
57
+
58
+ if self.dp_rank > 0:
59
+ self.to_dynamic_projection = nn.Linear(dim, dp_rank * num_heads)
60
+ # The LN of DualLN corresponding to dynamic projection
61
+ self.dual_ln_dp = nn.LayerNorm(dim)
62
+ # The LN of DualLN corresponding to all the tokens
63
+ self.dual_ln_full = nn.LayerNorm(dim)
64
+
65
+ # Adapted from ViL: https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/longformer2d.py#L55-L100
66
+ # We only add RPE to window attention.
67
+ # Unnecessary to add bias for global tokens, since DualLN already adds biases.
68
+ self.rpe = rpe
69
+ if rpe:
70
+ # handle the border conditions...
71
+ w_pad = int(w * 0.5)
72
+ self.local_relative_position_bias_table = nn.Parameter(
73
+ torch.zeros(2 * (w + w_pad - 1) * (2 * w_pad + w + 1) + 1, num_heads)
74
+ )
75
+ trunc_normal_(self.local_relative_position_bias_table, std=0.02)
76
+
77
+ # get pair-wise relative position index
78
+ coords_h = torch.arange(-w_pad, w_pad + w)
79
+ coords_w = torch.arange(-w_pad, w_pad + w)
80
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 2w, 2w
81
+ coords = (
82
+ coords.view(2, (w + w_pad * 2) ** 2).transpose(0, 1).unsqueeze(0)
83
+ ) # 1, 4w**2, 2
84
+ q_coords_hw = torch.arange(0, w)
85
+ q_coords = torch.stack(
86
+ torch.meshgrid([q_coords_hw, q_coords_hw])
87
+ ) # 2, w, w
88
+ q_coords = q_coords.view(2, w**2).transpose(0, 1).unsqueeze(1) # w**2, 1, 2
89
+ relative_coords = q_coords - coords
90
+ relative_coords += w_pad + w - 1 # shift to start from 0
91
+ relative_coords[:, :, 0] *= 2 * w_pad + w
92
+ relative_position_index = relative_coords.sum(-1) # w^2, 4w^2
93
+ self.register_buffer("relative_position_index", relative_position_index)
94
+
95
+ def forward(self, x, nx=None, ny=None):
96
+ B, N, C = x.shape
97
+ N_feat = N - self.nglo
98
+ self.img_size = int(math.sqrt(N)) if nx is None else nx
99
+ qkv = self.qkv(x)
100
+ # query, key, value
101
+ q, k, v = qkv.chunk(3, dim=2)
102
+ q = q.mul(self.scale)
103
+
104
+ # Layer norm on the projected keys and values
105
+ k = self.dual_ln_full(k)
106
+ v = self.dual_ln_full(v)
107
+
108
+ # output size: bsz x n_heads x seqlen x d
109
+ if self.nglo > 0:
110
+ q_cls, q = q[:, : self.nglo], q[:, self.nglo :]
111
+ k_cls, k = k[:, : self.nglo], k[:, self.nglo :]
112
+ v_cls, v = v[:, : self.nglo], v[:, self.nglo :]
113
+
114
+ q_cls = q_cls.reshape(
115
+ B, self.nglo, self.num_heads, C // self.num_heads
116
+ ).transpose(1, 2)
117
+ k_cls = k_cls.reshape(
118
+ B, self.nglo, self.num_heads, C // self.num_heads
119
+ ).transpose(1, 2)
120
+ v_cls = v_cls.reshape(
121
+ B, self.nglo, self.num_heads, C // self.num_heads
122
+ ).transpose(1, 2)
123
+
124
+ q = q.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
125
+ k = k.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
126
+ v = v.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
127
+
128
+ # Long-range Attention (Dynamic Projection)
129
+ if self.dp_rank > 0:
130
+ # b x h x r x (l w)
131
+ # Compute the projection matrix (P_i in the paper)
132
+ c_scores = (
133
+ self.to_dynamic_projection(x[:, self.nglo :])
134
+ .transpose(1, 2)
135
+ .contiguous()
136
+ .view(B, self.num_heads, self.dp_rank, -1)
137
+ )
138
+ # c_scores = c_scores.softmax(dim=-1, dtype=torch.float32).to(x)
139
+ c_scores = c_scores.softmax(dim=-1).to(
140
+ x
141
+ ) # Changed when experimenting with mixed precision (Johannes S.)
142
+ # b x h x r x d
143
+ k_lms = c_scores.matmul(k)
144
+ k_lms = k_lms.transpose(1, 2).contiguous().view(B, self.dp_rank, -1)
145
+ k_lms = (
146
+ self.dual_ln_dp(k_lms)
147
+ .view(B, self.dp_rank, self.num_heads, -1)
148
+ .contiguous()
149
+ .permute(0, 2, 3, 1)
150
+ )
151
+ # b x h x (lw) x r
152
+ dots_all = q.matmul(k_lms)
153
+
154
+ if self.window_size > 0:
155
+ # Switch the order of dimensions if using window attention.
156
+ dots_all = self.group_dots(dots_all)
157
+ else:
158
+ dots_all = None
159
+
160
+ # Short-term Attention (Window Attention)
161
+ # In our window attention, each token attends to at most (4w^2) tokens.
162
+ if self.window_size > 0:
163
+ dots_win = self.compute_window_scores(q, k)
164
+ w2 = int(self.window_size * self.window_size)
165
+
166
+ if self.rpe:
167
+ w_pad = int(0.5 * self.window_size)
168
+ local_relative_position_bias = self.local_relative_position_bias_table[
169
+ self.relative_position_index.view(-1)
170
+ ].view(
171
+ 1, w2, (w_pad * 2 + self.window_size) ** 2, -1
172
+ ) # w^2, kv_nums,H
173
+ local_relative_position_bias = (
174
+ local_relative_position_bias.permute(0, 3, 1, 2)
175
+ .expand(B, -1, -1, -1)
176
+ .unsqueeze(2)
177
+ .unsqueeze(2)
178
+ )
179
+
180
+ dots_win += local_relative_position_bias
181
+ if dots_all is None:
182
+ dots_all = dots_win
183
+ else:
184
+ dots_all = torch.cat([dots_all, dots_win], dim=-1)
185
+
186
+ # Global token.
187
+ if self.nglo > 0:
188
+ # and compute the scores of queries on CLS
189
+ dots_q_cls = q.matmul(k_cls.transpose(-1, -2))
190
+
191
+ if self.window_size > 0:
192
+ dots_q_cls = self.group_dots(dots_q_cls)
193
+ dots_all = torch.cat([dots_all, dots_q_cls], dim=-1)
194
+
195
+ # attn = dots_all.softmax(dim=-1, dtype=torch.float32).to(x)
196
+ attn = dots_all.softmax(dim=-1).to(
197
+ x
198
+ ) # Changed when experimenting with mixed precision (Johannes S.)
199
+ attn = self.attn_drop(attn)
200
+ out = 0
201
+ if self.window_size > 0:
202
+ offset = max(0, self.dp_rank)
203
+ kv_group_size = self.window_size
204
+ total_win_size = max(1, self.window_size // 2) * 2 + kv_group_size
205
+ attn_win = attn[:, :, :, :, :, offset : offset + total_win_size**2]
206
+ out += self.compute_window_pv(attn_win, v)
207
+ attn = self.ungroup_dots(attn)
208
+
209
+ # attn will be b x h x lw x n_k from now on
210
+ if self.dp_rank > 0:
211
+ attn_lm = attn[:, :, :, : self.dp_rank]
212
+ v_lms = (
213
+ # c_scores.matmul(v.float())
214
+ c_scores.matmul(
215
+ v
216
+ ) # Changed when experimenting with mixed precision (Johannes S.)
217
+ .to(v)
218
+ .transpose(1, 2)
219
+ .contiguous()
220
+ .view(B, self.dp_rank, -1)
221
+ )
222
+ v_lms = (
223
+ self.dual_ln_dp(v_lms)
224
+ .view(B, self.dp_rank, self.num_heads, -1)
225
+ .contiguous()
226
+ .transpose(1, 2)
227
+ )
228
+
229
+ out += attn_lm.matmul(v_lms)
230
+
231
+ if self.nglo > 0:
232
+ attn_cls = attn[:, :, :, -self.nglo :]
233
+ out += attn_cls.matmul(
234
+ v_cls
235
+ ) # Changed. Was `.mul` instead of `.matmul`. (JWS)
236
+
237
+ # b x h x 1 x lw
238
+ cls_inner = q_cls.matmul(k_cls.transpose(-1, -2))
239
+ cls_dots = q_cls.matmul(
240
+ k.transpose(-1, -2)
241
+ ) # Changed. Was `out` instead of `k`. (JWS)
242
+ cls_dots = torch.cat([cls_inner, cls_dots], dim=-1)
243
+
244
+ # cls_dots = cls_dots.softmax(dim=-1, dtype=torch.float32).to(x)
245
+ cls_dots = cls_dots.softmax(dim=-1).to(
246
+ x
247
+ ) # Changed when experimenting with mixed precision (Johannes S.)
248
+ cls_next = cls_dots[:, :, :, self.nglo :].matmul(
249
+ v
250
+ ) # the post_cls variant # Changed. Was `out` instead of `v`. (JWS)
251
+ cls_next += cls_dots[:, :, :, : self.nglo].matmul(v_cls)
252
+
253
+ out = torch.cat([cls_next, out], dim=2)
254
+ out = out.transpose(1, 2).contiguous().view(B, N, -1)
255
+
256
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
257
+ out = self.proj(out)
258
+ out = self.proj_drop(out)
259
+ return out
260
+
261
+ def compute_window_scores(self, q, k):
262
+ """Compute the inner products for the window attention.
263
+ First, divide the query into non-overlapping windows.
264
+ Then, use torch.as_strided (implemented in self.get_overlapping_tiles) to create a view of the keys
265
+ that corresponds to the windows with at most 2x memory overhead.
266
+ Finally, compute the inner product.
267
+ """
268
+ # q: b h (l w) d
269
+ b, h, _, d = q.shape
270
+ side_size = max(self.window_size // 2, 1)
271
+ # q_group_size: segment size
272
+ kv_width = 2 * side_size + self.window_size # assuming q_stride=1
273
+ q_n_group = self.img_size // self.window_size
274
+ q_tiles = q.reshape(
275
+ b, h, q_n_group, self.window_size, q_n_group, self.window_size, d
276
+ ).permute(0, 1, 2, 4, 3, 5, 6)
277
+ # q_tiles: b x h x n_group x n_group x w^2 x d
278
+ q_tiles = q_tiles.contiguous().view(b, h, q_n_group, q_n_group, -1, d)
279
+
280
+ # k_tiles: b x h x n_group x n_group x 9w^2 x d
281
+ k_tiles = (
282
+ self.get_overlapping_tiles(k)
283
+ .contiguous()
284
+ .view(b, h, q_n_group, q_n_group, -1, d)
285
+ )
286
+ # dot_tiles: b x h x n_group x n_group x w^2 x 9w^2
287
+ dot_tiles = q_tiles.matmul(k_tiles.transpose(-1, -2))
288
+
289
+ # fill "-inf" into the zero-padding parts
290
+ dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width, kv_width)
291
+
292
+ dot_tiles[:, :, 0, :, :, :side_size].fill_(float("-inf"))
293
+ dot_tiles[:, :, -1, :, :, -side_size:].fill_(float("-inf"))
294
+ dot_tiles[:, :, :, 0, :, :, :side_size].fill_(float("-inf"))
295
+ dot_tiles[:, :, :, -1, :, :, -side_size:].fill_(float("-inf"))
296
+
297
+ dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width**2)
298
+ return dot_tiles
299
+
300
+ def get_overlapping_tiles(self, x):
301
+ """Get overlapping tiles in the 2D spatial domain, ensuring each query computes correlation with all neighbors"""
302
+ # x: b h (l w) d
303
+ b, h, _, d = x.shape
304
+ side_size = max(self.window_size // 2, 1)
305
+ total_size = 2 * side_size + self.window_size
306
+ kv_group_size = self.window_size
307
+ kv_width = self.img_size
308
+
309
+ x = x.view(b, h, kv_width, kv_width, d)
310
+ x = F.pad(x, [0, 0, side_size, side_size, side_size, side_size], value=0)
311
+
312
+ out_shape = [
313
+ b,
314
+ h,
315
+ kv_width // kv_group_size,
316
+ kv_width // kv_group_size,
317
+ total_size,
318
+ total_size,
319
+ d,
320
+ ]
321
+ in_stride = x.stride()
322
+ out_stride = [
323
+ in_stride[0],
324
+ in_stride[1],
325
+ in_stride[2] * kv_group_size,
326
+ in_stride[3] * kv_group_size,
327
+ in_stride[2],
328
+ in_stride[3],
329
+ in_stride[4],
330
+ ]
331
+
332
+ # note we ignored the boundary here
333
+ return x.as_strided(size=out_shape, stride=out_stride)
334
+
335
+ def compute_window_pv(self, attn, v):
336
+ """Compute the inner product of attention matrix and the values for the window attention."""
337
+ b, h, n_group, _, w2, n_k = attn.shape
338
+ d = v.shape[-1]
339
+ v_tiles = (
340
+ self.get_overlapping_tiles(v)
341
+ .contiguous()
342
+ .view(b, h, n_group, n_group, -1, d)
343
+ )
344
+
345
+ # b x h x n_group x n_group x w^2 x d
346
+ pv = attn.matmul(v_tiles)
347
+ # return: b x h x (lw) x d
348
+ ret = self.ungroup_dots(pv)
349
+
350
+ return ret
351
+
352
+ def group_dots(self, dots):
353
+ b, h = dots.shape[:2]
354
+ n_group = self.img_size // self.window_size
355
+ dots = dots.reshape(
356
+ b, h, n_group, self.window_size, n_group, self.window_size, -1
357
+ ).permute(0, 1, 2, 4, 3, 5, 6)
358
+ dots = dots.contiguous().view(
359
+ b, h, n_group, n_group, self.window_size * self.window_size, -1
360
+ )
361
+ return dots
362
+
363
+ def ungroup_dots(self, dots):
364
+ b, h, n_group, _, _, n_keys = dots.shape
365
+ dots = dots.reshape(
366
+ b, h, n_group, n_group, self.window_size, self.window_size, -1
367
+ ).permute(0, 1, 2, 4, 3, 5, 6)
368
+ dots = dots.contiguous().view(b, h, -1, n_keys)
369
+ return dots
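
The long-range branch of AttentionLS compresses the N keys and values into dp_rank dynamic-projection tokens per head before attending, which is what keeps the cost linear in sequence length. Below is a stripped-down sketch of that compression alone; window attention, global tokens and the DualLN layers are omitted.

import torch

B, H, N, d, r = 2, 4, 256, 16, 2    # batch, heads, tokens, head dim, dp_rank
q = torch.randn(B, H, N, d)
k = torch.randn(B, H, N, d)
v = torch.randn(B, H, N, d)

# Data-dependent projection scores, one row per low-rank slot (P_i in the paper).
c_scores = torch.randn(B, H, r, N).softmax(dim=-1)

k_low = c_scores @ k    # (B, H, r, d): compressed keys
v_low = c_scores @ v    # (B, H, r, d): compressed values

attn = (q @ k_low.transpose(-1, -2)).softmax(dim=-1)   # (B, H, N, r)
out = attn @ v_low                                      # (B, H, N, d)
print(out.shape)
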
surya/utils/__init__.py ADDED
File without changes
surya/utils/config.py ADDED
@@ -0,0 +1,311 @@
1
+ import os
2
+ from argparse import Namespace
3
+
4
+ import yaml
5
+
6
+
7
+ class DataConfig:
8
+ def __init__(
9
+ self,
10
+ train_data_path: str,
11
+ valid_data_path: str,
12
+ batch_size: int,
13
+ num_data_workers: int,
14
+ prefetch_factor: int,
15
+ time_delta_input_minutes: list[int],
16
+ n_input_timestamps: int | None = None,
17
+ pooling: int | None = None,
18
+ random_vert_flip: bool = False,
19
+ **kwargs,
20
+ ):
21
+ self.__dict__.update(kwargs)
22
+
23
+ self.train_data_path = train_data_path
24
+ self.valid_data_path = valid_data_path
25
+ self.batch_size = batch_size
26
+ self.num_data_workers = num_data_workers
27
+ self.prefetch_factor = prefetch_factor
28
+ self.time_delta_input_minutes = sorted(time_delta_input_minutes)
29
+ self.n_input_timestamps = n_input_timestamps
30
+ self.pooling = pooling
31
+ self.random_vert_flip = random_vert_flip
32
+
33
+ if self.n_input_timestamps is None:
34
+ self.n_input_timestamps = len(self.time_delta_input_minutes)
35
+
36
+ assert (
37
+ self.n_input_timestamps > 0
38
+ ), "Number of input timestamps must be greater than 0."
39
+ assert self.n_input_timestamps <= len(self.time_delta_input_minutes), (
40
+ f"Cannot sample {self.n_input_timestamps} from list of "
41
+ f"{self.time_delta_input_minutes} input timestamps."
42
+ )
43
+
44
+ def to_dict(self):
45
+ return self.__dict__
46
+
47
+ @staticmethod
48
+ def from_argparse(args: Namespace):
49
+ return DataConfig(**args.__dict__)
50
+
51
+ def __str__(self):
52
+ return (
53
+ f"Training index: {self.train_data_path}, "
54
+ f"Validation index: {self.valid_data_path}, "
55
+ )
56
+
57
+ def __repr__(self):
58
+ return (
59
+ f"Training index: {self.train_data_path}, "
60
+ f"Validation index: {self.valid_data_path}, "
61
+ )
62
+
63
+
64
+ class ModelConfig:
65
+ def __init__(
66
+ self,
67
+ # enc_num_layers: int,
68
+ # enc_num_heads: int,
69
+ # enc_embed_size: int,
70
+ # dec_num_layers: int,
71
+ # dec_num_heads: int,
72
+ # dec_embed_size: int,
73
+ # mask_ratio: float,
74
+ **kwargs,
75
+ ):
76
+ self.__dict__.update(kwargs)
77
+
78
+ # self.enc_num_layers = enc_num_layers
79
+ # self.enc_num_heads = enc_num_heads
80
+ # self.enc_embed_size = enc_embed_size
81
+ # self.dec_num_layers = dec_num_layers
82
+ # self.dec_num_heads = dec_num_heads
83
+ # self.dec_embed_size = dec_embed_size
84
+ # self.mlp_ratio = 0.0
85
+ # self.mask_ratio = mask_ratio
86
+
87
+ self.__dict__.update(kwargs)
88
+
89
+ def to_dict(self):
90
+ return self.__dict__
91
+
92
+ @staticmethod
93
+ def from_argparse(args: Namespace):
94
+ return ModelConfig(**args.__dict__)
95
+
96
+ @property
97
+ def encoder_d_ff(self):
98
+ return int(self.enc_embed_size * self.mlp_ratio)
99
+
100
+ @property
101
+ def decoder_d_ff(self):
102
+ return int(self.dec_embed_size * self.mlp_ratio)
103
+
104
+ def __str__(self):
105
+ return (
106
+ f"Input channels: {self.model.in_channels}, "
107
+ f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
108
+ f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
109
+ )
110
+
111
+ def __repr__(self):
112
+ return (
113
+ f"Input channels: {self.model.in_channels}, "
114
+ f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
115
+ f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
116
+ )
117
+
118
+
119
+ class OptimizerConfig:
120
+ def __init__(
121
+ self,
122
+ warm_up_steps: int,
123
+ max_epochs: int,
124
+ learning_rate: float,
125
+ min_lr: float,
126
+ ):
127
+ self.warm_up_steps = warm_up_steps
128
+ self.max_epochs = max_epochs
129
+ self.learning_rate = learning_rate
130
+ self.min_lr = min_lr
131
+
132
+ def to_dict(self):
133
+ return self.__dict__
134
+
135
+ @staticmethod
136
+ def from_argparse(args: Namespace):
137
+ return OptimizerConfig(**args.__dict__)
138
+
139
+ def __str__(self):
140
+ return (
141
+ f"Epochs: {self.max_epochs}, "
142
+ f"LR: {[self.learning_rate, self.min_lr]}, "
143
+ f"Warm up: {self.warm_up_steps},"
144
+ )
145
+
146
+ def __repr__(self):
147
+ return (
148
+ f"Epochs: {self.max_epochs}, "
149
+ f"LR: {[self.learning_rate, self.min_lr]}, "
150
+ f"Warm up: {self.warm_up_steps},"
151
+ )
152
+
153
+
154
+ class ExperimentConfig:
155
+ def __init__(
156
+ self,
157
+ job_id: str,
158
+ data_config: DataConfig,
159
+ model_config: ModelConfig,
160
+ optimizer_config: OptimizerConfig,
161
+ path_experiment: str,
162
+ parallelism: str,
163
+ from_checkpoint: str | None = None,
164
+ **kwargs,
165
+ ):
166
+ # additional experiment parameters used in downstream tasks
167
+ self.__dict__.update(kwargs)
168
+
169
+ self.job_id = job_id
170
+ self.data = data_config
171
+ self.model = model_config
172
+ self.optimizer = optimizer_config
173
+ self.path_experiment = path_experiment
174
+ self.from_checkpoint = from_checkpoint
175
+ self.parallelism = parallelism
176
+
177
+ assert self.model.in_channels == len(self.data.channels), (
178
+ f"Number of model input channels ({self.model.in_channels}) must be "
179
+ f"equal to number of input variables ({len(self.data.channels)})."
180
+ )
181
+ if self.model.time_embedding["type"] == "linear":
182
+ assert (
183
+ self.model.time_embedding["time_dim"] == self.data.n_input_timestamps
184
+ ), "Time dimension of linear embedding must be equal to number of input timestamps."
185
+ if self.rollout_steps > 0:
186
+ assert self.data.n_input_timestamps == len(
187
+ self.data.time_delta_input_minutes
188
+ ), "Rollout does not support randomly sampled input timestamps."
189
+
190
+ metrics_channels = []
191
+ for field1, value1 in self.metrics["train_metrics_config"].items():
192
+ for field2, value2 in self.metrics["train_metrics_config"][field1].items():
193
+ if field2 == "metrics":
194
+ for metric_definition in value2:
195
+ split_metric_definition = metric_definition.split(":")
196
+ channels = (
197
+ split_metric_definition[2]
198
+ if len(split_metric_definition) > 2
199
+ else None
200
+ )
201
+ if channels is not None:
202
+ metrics_channels = metrics_channels + channels.split("...")
203
+
204
+ for field1, value1 in self.metrics["validation_metrics_config"].items():
205
+ for field2, value2 in self.metrics["validation_metrics_config"][
206
+ field1
207
+ ].items():
208
+ if field2 == "metrics":
209
+ for metric_definition in value2:
210
+ split_metric_definition = metric_definition.split(":")
211
+ channels = (
212
+ split_metric_definition[2]
213
+ if len(split_metric_definition) > 2
214
+ else None
215
+ )
216
+ if channels is not None:
217
+ metrics_channels = metrics_channels + channels.replace(
218
+ "...", "&"
219
+ ).split("&")
220
+
221
+ assert set(metrics_channels).issubset(self.data.channels), (
222
+ f"{set(metrics_channels).difference(self.data.channels)} "
223
+ f"not part of data input channels."
224
+ )
225
+
226
+ assert self.parallelism in [
227
+ "ddp",
228
+ "fsdp",
229
+ ], 'Valid choices for `parallelism` are "ddp" and "fsdp".'
230
+
231
+ @property
232
+ def path_checkpoint(self) -> str:
233
+ if self.path_experiment == "":
234
+ return os.path.join(self.path_weights, "train", "checkpoint.pt")
235
+ else:
236
+ return os.path.join(
237
+ os.path.dirname(self.path_experiment),
238
+ "weights",
239
+ "train",
240
+ "checkpoint.pt",
241
+ )
242
+
243
+ @property
244
+ def path_weights(self) -> str:
245
+ return os.path.join(self.path_experiment, self.make_suffix_path(), "weights")
246
+
247
+ @property
248
+ def path_states(self) -> str:
249
+ return os.path.join(self.path_experiment, self.make_suffix_path(), "states")
250
+
251
+ def to_dict(self):
252
+ d = self.__dict__.copy()
253
+ d["model"] = self.model.to_dict()
254
+ d["data"] = self.data.to_dict()
255
+
256
+ return d
257
+
258
+ @staticmethod
259
+ def from_argparse(args: Namespace):
260
+ return ExperimentConfig(
261
+ data_config=DataConfig.from_argparse(args),
262
+ model_config=ModelConfig.from_argparse(args),
263
+ optimizer_config=OptimizerConfig.from_argparse(args),
264
+ **args.__dict__,
265
+ )
266
+
267
+ @staticmethod
268
+ def from_dict(params: dict):
269
+ return ExperimentConfig(
270
+ data_config=DataConfig(**params["data"]),
271
+ model_config=ModelConfig(**params["model"]),
272
+ optimizer_config=OptimizerConfig(**params["optimizer"]),
273
+ **params,
274
+ )
275
+
276
+ def make_folder_name(self) -> str:
277
+ param_folder = "wpt-c1-s1"
278
+ return param_folder
279
+
280
+ def make_suffix_path(self) -> str:
281
+ return os.path.join(self.job_id)
282
+
283
+ def __str__(self):
284
+ return (
285
+ f"ID: {self.job_id}, "
286
+ f"Epochs: {self.optimizer.max_epochs}, "
287
+ f"Batch size: {self.data.batch_size}, "
288
+ f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
289
+ f"Warm up: {self.optimizer.warm_up_steps},"
290
+ f"DL workers: {self.data.num_data_workers},"
291
+ f"Parallelism: {self.parallelism}"
292
+ )
293
+
294
+ def __repr__(self):
295
+ return (
296
+ f"ID: {self.job_id}, "
297
+ f"Epochs: {self.optimizer.max_epochs}, "
298
+ f"Batch size: {self.data.batch_size}, "
299
+ f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
300
+ f"Warm up: {self.optimizer.warm_up_steps},"
301
+ f"DL workers: {self.data.num_data_workers},"
302
+ f"Parallelism: {self.parallelism}"
303
+ )
304
+
305
+
306
+ def get_config(
307
+ config_path: str,
308
+ ) -> ExperimentConfig:
309
+ cfg = yaml.safe_load(open(config_path, "r"))
310
+ cfg["data"]["scalers"] = yaml.safe_load(open(cfg["data"]["scalers_path"], "r"))
311
+ return ExperimentConfig.from_dict(params=cfg)
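
The configuration classes validate their inputs at construction time; ExperimentConfig is normally assembled from a YAML file via get_config. A small sketch instantiating DataConfig directly, with placeholder paths and values:

from surya.utils.config import DataConfig

# Illustrative values only; real experiments read these from the YAML config.
data_cfg = DataConfig(
    train_data_path="index_train.csv",
    valid_data_path="index_valid.csv",
    batch_size=1,
    num_data_workers=2,
    prefetch_factor=2,
    time_delta_input_minutes=[-60, 0],
)
print(data_cfg.n_input_timestamps)   # defaults to len(time_delta_input_minutes) -> 2
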
surya/utils/data.py ADDED
@@ -0,0 +1,176 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from surya.datasets.transformations import Transformation, StandardScaler
7
+ from surya.utils.config import DataConfig
8
+ from surya.utils.misc import class_from_name, view_as_windows
9
+
10
+
11
+ def custom_collate_fn(batch):
12
+ """
13
+ Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.
14
+
15
+ This function separately processes the data and metadata from the input batch.
16
+
17
+ - The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to incompatible data types,
18
+ the batch is returned as-is.
19
+
20
+ - The `metadata_batch` is assumed to be a dictionary, where each key corresponds to a list of values across the batch.
21
+ Each key is collated using `default_collate`. If collation fails for a particular key, the original list of values
22
+ is retained.
23
+
24
+ Example usage for accessing collated metadata:
25
+ - `collated_metadata['timestamps_input'][batch_idx][input_time]`
26
+ - `collated_metadata['timestamps_input'][batch_idx][rollout_step]`
27
+
28
+ Args:
29
+ batch (list of tuples): Each tuple contains (data, metadata), where:
30
+ - `data` is a tensor or other data structure used for training.
31
+ - `metadata` is a dictionary containing additional information.
32
+
33
+ Returns:
34
+ tuple: (collated_data, collated_metadata)
35
+ - `collated_data`: The processed batch of data.
36
+ - `collated_metadata`: The processed batch of metadata.
37
+ """
38
+
39
+ # Unpack batch into separate lists of data and metadata
40
+ data_batch, metadata_batch = zip(*batch)
41
+
42
+ # Attempt to collate the data batch using PyTorch's default collate function
43
+ try:
44
+ collated_data = torch.utils.data.default_collate(data_batch)
45
+ except TypeError:
46
+ # If default_collate fails (e.g., due to incompatible types), return the data batch as-is
47
+ collated_data = data_batch
48
+
49
+ # Handle metadata collation
50
+ if isinstance(metadata_batch[0], dict):
51
+ collated_metadata = {}
52
+ for key in metadata_batch[0].keys():
53
+ values = [d[key] for d in metadata_batch]
54
+ try:
55
+ # Attempt to collate values under the current key
56
+ collated_metadata[key] = torch.utils.data.default_collate(values)
57
+ except TypeError:
58
+ # If collation fails, keep the values as a list
59
+ collated_metadata[key] = values
60
+ else:
61
+ # If metadata is not a dictionary, try to collate it as a whole
62
+ try:
63
+ collated_metadata = torch.utils.data.default_collate(metadata_batch)
64
+ except TypeError:
65
+ # If collation fails, return metadata as-is
66
+ collated_metadata = metadata_batch
67
+
68
+ return collated_data, collated_metadata
69
+
70
+
71
+ def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
72
+ return (raw_size - win_size) // stride + 1
73
+
74
+
75
+ def get_scalers_info(dataset) -> dict:
76
+ return {
77
+ k: (type(v).__module__, type(v).__name__, v.to_dict())
78
+ for k, v in dataset.scalers.items()
79
+ }
80
+
81
+
82
+ def build_scalers_pressure(info: dict) -> Dict[str, Transformation]:
83
+ ret_dict = {k: dict() for k in info.keys()}
84
+ for var_key, var_d in info.items():
85
+ for p_key, p_val in var_d.items():
86
+ ret_dict[var_key][p_key] = class_from_name(
87
+ p_val["base"], p_val["class"]
88
+ ).from_dict(p_val)
89
+ return ret_dict
90
+
91
+
92
+ def build_scalers(info: dict) -> Dict[str, Transformation]:
93
+ ret_dict = {k: None for k in info.keys()}
94
+ for p_key, p_val in info.items():
95
+ ret_dict[p_key]: StandardScaler = class_from_name(
96
+ p_val["base"], p_val["class"]
97
+ ).from_dict(p_val)
98
+ return ret_dict
99
+
100
+
101
+ def break_batch_5d(
102
+ data: list, lat_size: int, lon_size: int, time_steps: int
103
+ ) -> torch.Tensor:
104
+ """
105
+ data: list of samples, each sample is [C, T, L, H, W]
106
+ """
107
+ num_levels = data[0].shape[2]
108
+ num_vars = data[0].shape[0]
109
+ big_batch = np.stack(data, axis=0)
110
+ vw = view_as_windows(
111
+ big_batch,
112
+ [1, num_vars, time_steps, num_levels, lat_size, lon_size],
113
+ step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
114
+ ).squeeze()
115
+ # To check if it is correctly reshaping
116
+ # idx = 30
117
+ # (big_batch[0, :, idx:idx+2, :, 40:80, 40:80]-vw[idx//2, 1, 1]).sum()
118
+ vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
119
+ # How to test:
120
+ # (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
121
+ # (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
122
+ # (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
123
+
124
+ # Need to move axis because Weather model is expecting [C, L, T, H, W] instead of [C, T, L, H, W]
125
+ vw = np.moveaxis(vw, 3, 2)
126
+ vw = torch.tensor(vw, dtype=torch.float32)
127
+ return vw
128
+
129
+
130
+ def break_batch_5d_aug(data: list, cfg: DataConfig, max_batch: int = 256) -> torch.Tensor:
131
+ num_levels = data[0].shape[2]
132
+ num_vars = data[0].shape[0]
133
+ big_batch = np.stack(data, axis=0)
134
+
135
+ y_step, x_step, t_step = (
136
+ cfg.patch_size_lat // 2,
137
+ cfg.patch_size_lon // 2,
138
+ cfg.patch_size_time // 2,
139
+ )
140
+ y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
141
+ x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
142
+ t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
143
+ max_batch = min(max_batch, y_max * x_max * t_max)
144
+
145
+ batch = np.empty(
146
+ (
147
+ max_batch,
148
+ num_vars,
149
+ cfg.input_size_time,
150
+ num_levels,
151
+ cfg.input_size_lat,
152
+ cfg.input_size_lon,
153
+ ),
154
+ dtype=np.float32,
155
+ )
156
+ for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
157
+ t, y, x = np.unravel_index(
158
+ i,
159
+ (
160
+ t_max,
161
+ y_max,
162
+ x_max,
163
+ ),
164
+ )
165
+ batch[j] = big_batch[
166
+ :, # batch_id
167
+ :, # vars
168
+ t * t_step : t * t_step + cfg.input_size_time,
169
+ :, # levels
170
+ y * y_step : y * y_step + cfg.input_size_lat,
171
+ x * x_step : x * x_step + cfg.input_size_lon,
172
+ ]
173
+
174
+ batch = np.moveaxis(batch, 3, 2)
175
+ batch = torch.tensor(batch, dtype=torch.float32)
176
+ return batch
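The docstring of `custom_collate_fn` describes how tensors are stacked while non-collatable metadata falls back to plain lists. Below is a self-contained toy sketch; the `MiniDataset` is illustrative only, standing in for `HelioNetCDFDataset`, which likewise yields `(data, metadata)` pairs.

    import torch
    from torch.utils.data import Dataset, DataLoader
    from surya.utils.data import custom_collate_fn

    class MiniDataset(Dataset):
        """Tiny stand-in dataset returning (tensor, metadata-dict) pairs."""
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            data = torch.randn(3, 16, 16)                          # stacked by default_collate
            metadata = {
                "timestamps_input": [f"2014-01-07T{idx:02d}:00"],  # strings stay as lists
                "index": idx,                                      # ints become a tensor
            }
            return data, metadata

    loader = DataLoader(MiniDataset(), batch_size=4, collate_fn=custom_collate_fn)
    batch_data, batch_meta = next(iter(loader))
    print(batch_data.shape)                  # torch.Size([4, 3, 16, 16])
    print(batch_meta["timestamps_input"])    # timestamps grouped across the batch
    print(batch_meta["index"])               # tensor([0, 1, 2, 3])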
surya/utils/distributed.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ import random
3
+ from datetime import timedelta
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ from torch.distributed import checkpoint as dist_checkpoint
11
+ from torch.distributed import fsdp
12
+
13
+ import functools
14
+ import itertools
15
+
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from torch.utils.data import Dataset
18
+ from typing import Any, Dict, Optional
19
+
20
+ from surya.utils.schemas import TrainState
21
+
22
+
23
+ def init_dist(device: str, rank: int, world_size: int):
24
+ torch.distributed.init_process_group(
25
+ device,
26
+ init_method="env://",
27
+ world_size=world_size,
28
+ rank=rank,
29
+ timeout=timedelta(minutes=60),
30
+ )
31
+
32
+
33
+ def init_ddp(use_gpu: bool):
34
+ local_rank = int(os.environ["LOCAL_RANK"])
35
+ rank = int(os.environ["RANK"])
36
+ world_size = int(os.environ["WORLD_SIZE"])
37
+
38
+ if use_gpu:
39
+ assert (
40
+ torch.cuda.is_available()
41
+ ), "GPU requested but none was found in the system."
42
+
43
+ if use_gpu:
44
+ init_dist("nccl", rank, world_size)
45
+ torch.cuda.set_device(local_rank)
46
+ os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
47
+ os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = str(1)
48
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
49
+ cudnn.benchmark = True
50
+ else:
51
+ init_dist("gloo", rank, world_size)
52
+ return local_rank, rank
53
+
54
+
55
+ def set_global_seed(rank):
56
+ random.seed(42 + rank)
57
+ torch.cuda.manual_seed(42 + rank)
58
+ torch.manual_seed(42 + rank)
59
+ np.random.seed(42 + rank)
60
+
61
+
62
+ def is_dist_avail_and_initialized():
63
+ if not dist.is_available():
64
+ return False
65
+ if not dist.is_initialized():
66
+ return False
67
+ return True
68
+
69
+
70
+ def get_world_size():
71
+ if not is_dist_avail_and_initialized():
72
+ return 1
73
+ return dist.get_world_size()
74
+
75
+
76
+ def get_rank():
77
+ if not is_dist_avail_and_initialized():
78
+ return 0
79
+ return dist.get_rank()
80
+
81
+
82
+ def is_main_process():
83
+ return get_rank() == 0
84
+
85
+
86
+ # def save_model_singular(model, *args, **kwargs):
87
+ # """Stream all model parameters to rank 0 on the CPU, then pass all
88
+ # other given arguments to `torch.save` to save the model, but only on
89
+ # the root process.
90
+ # """
91
+ # save_policy = fsdp.FullStateDictConfig(
92
+ # offload_to_cpu=True, rank0_only=True)
93
+ # with fsdp.FullyShardedDataParallel.state_dict_type(
94
+ # model,
95
+ # fsdp.StateDictType.FULL_STATE_DICT,
96
+ # save_policy,
97
+ # ):
98
+ # cpu_state = model.state_dict()
99
+ # # We do *not* want to write to the same location with multiple
100
+ # # processes at the same time.
101
+ # if is_root_process():
102
+ # torch.save(cpu_state, *args, **kwargs)
103
+
104
+
105
+ def save_model(model, save_dir):
106
+ """Obtain sharded model parameters from the GPU, then save the model
107
+ as a distributed checkpoint to the given directory. Saving a
108
+ distributed checkpoint means that the checkpoint will be split into
109
+ individual files, one for each process.
110
+ """
111
+ state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
112
+ with fsdp.FullyShardedDataParallel.state_dict_type(
113
+ model,
114
+ fsdp.StateDictType.SHARDED_STATE_DICT,
115
+ state_dict_config,
116
+ ):
117
+ cp_state_dict = {"model": model.state_dict()}
118
+ dist_checkpoint.save_state_dict(
119
+ cp_state_dict,
120
+ dist_checkpoint.FileSystemWriter(save_dir),
121
+ )
122
+
123
+
124
+ def load_model(model, load_dir):
125
+ """Set the given model's state dictionary in-place from the given
126
+ distributed checkpoint directory.
127
+ """
128
+ state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
129
+ with fsdp.FullyShardedDataParallel.state_dict_type(
130
+ model,
131
+ fsdp.StateDictType.SHARDED_STATE_DICT,
132
+ state_dict_config,
133
+ ):
134
+ cp_state_dict = {"model": model.state_dict()}
135
+ dist_checkpoint.load_state_dict(
136
+ cp_state_dict,
137
+ dist_checkpoint.FileSystemReader(load_dir),
138
+ )
139
+ model.load_state_dict(cp_state_dict["model"])
140
+
141
+
142
+ @functools.lru_cache(maxsize=None)
143
+ def is_root_process():
144
+ """Return whether this process is the root process."""
145
+ return torch.distributed.get_rank() == 0
146
+
147
+
148
+ # The reason we define this is that `torch.distributed` does not
149
+ # implement it; for the global rank, there's
150
+ # `torch.distributed.get_rank()`.
151
+ @functools.lru_cache(maxsize=None)
152
+ def get_local_rank():
153
+ """Return the local rank of this process."""
154
+ return int(os.getenv("LOCAL_RANK"))
155
+
156
+
157
+ def print0(*args, **kwargs):
158
+ """Print something only on the root process."""
159
+ if (not dist.is_initialized()) or is_root_process():
160
+ print(*args, **kwargs)
161
+
162
+
163
+ def save_model_singular(model, save_path, parallelism, *args, **kwargs):
164
+ """Stream all model parameters to rank 0 on the CPU, then pass all
165
+ other given arguments to `torch.save` to save the model, but only on
166
+ the root process.
167
+ """
168
+
169
+ match parallelism:
170
+ case "fsdp":
171
+ save_policy = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
172
+ with fsdp.FullyShardedDataParallel.state_dict_type(
173
+ model,
174
+ fsdp.StateDictType.FULL_STATE_DICT,
175
+ save_policy,
176
+ ):
177
+ cpu_state = model.state_dict()
178
+ # We do *not* want to write to the same location with multiple
179
+ # processes at the same time.
180
+ if is_main_process():
181
+ if not os.path.exists(os.path.dirname(save_path)):
182
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
183
+ torch.save(obj=cpu_state, f=save_path, *args, **kwargs)
184
+
185
+ case "ddp":
186
+ if is_main_process():
187
+ torch.save(obj=model.module.state_dict(), f=save_path, *args, **kwargs)
188
+ dist.barrier()
189
+ case _:
190
+ raise ValueError(
191
+ f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
192
+ )
193
+
194
+
195
+ def save_optim_singular(
196
+ model: nn.Module,
197
+ optimizer: torch.optim.Optimizer,
198
+ save_path: str,
199
+ parallelism: str = "fsdp",
200
+ ):
201
+ match parallelism:
202
+ case "fsdp":
203
+ optim_state_dict_config = fsdp.FullOptimStateDictConfig(
204
+ offload_to_cpu=True, rank0_only=True
205
+ )
206
+
207
+ with fsdp.FullyShardedDataParallel.state_dict_type(
208
+ model,
209
+ fsdp.StateDictType.FULL_STATE_DICT,
210
+ optim_state_dict_config=optim_state_dict_config,
211
+ ):
212
+ optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
213
+ model, optimizer
214
+ )
215
+
216
+ if is_main_process():
217
+ if not os.path.exists(os.path.dirname(save_path)):
218
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
219
+ checkpoint = {
220
+ "optimizer_state_dict": optim_state_dict,
221
+ }
222
+ torch.save(checkpoint, f=save_path)
223
+ case "ddp":
224
+ if is_main_process():
225
+ optim_state_dict = optimizer.state_dict()
226
+ if not os.path.exists(os.path.dirname(save_path)):
227
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
228
+ torch.save(obj=optim_state_dict, f=save_path)
229
+ dist.barrier()
230
+ case _:
231
+ raise ValueError(
232
+ f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
233
+ )
234
+
235
+
236
+ def collect_optim_singular(
237
+ model: nn.Module, optimizer: torch.optim.Optimizer, parallelism: str = "fsdp"
238
+ ) -> dict:
239
+ optim_state_dict = {}
240
+ match parallelism:
241
+ case "fsdp":
242
+ optim_state_dict_config = fsdp.FullOptimStateDictConfig(
243
+ offload_to_cpu=True, rank0_only=True
244
+ )
245
+
246
+ with fsdp.FullyShardedDataParallel.state_dict_type(
247
+ model,
248
+ fsdp.StateDictType.FULL_STATE_DICT,
249
+ optim_state_dict_config=optim_state_dict_config,
250
+ ):
251
+ optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
252
+ model, optimizer
253
+ )
254
+
255
+ case "ddp":
256
+ if is_main_process():
257
+ optim_state_dict = optimizer.state_dict()
258
+ dist.barrier()
259
+ case _:
260
+ raise ValueError(
261
+ f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
262
+ )
263
+
264
+ return optim_state_dict
265
+
266
+
267
+ def save_state_singular(states: TrainState, save_path, *args, **kwargs):
268
+ """Stream all model parameters to rank 0 on the CPU, then pass all
269
+ other given arguments to `torch.save` to save paramters, but only on
270
+ the root process.
271
+ """
272
+ if is_main_process():
273
+ if not os.path.exists(os.path.dirname(save_path)):
274
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
275
+ torch.save(obj=states, f=save_path, *args, **kwargs)
276
+ dist.barrier()
277
+
278
+
279
+ class StatefulDistributedSampler(DistributedSampler):
280
+ _YIELDED = "yielded"
281
+
282
+ def __init__(
283
+ self,
284
+ dataset: Dataset,
285
+ num_replicas: Optional[int] = None,
286
+ rank: Optional[int] = None,
287
+ shuffle: bool = True,
288
+ seed: int = 0,
289
+ drop_last: bool = False,
290
+ ) -> None:
291
+ super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
292
+ self.yielded = 0
293
+ self.next_yielded = None
294
+
295
+ def __iter__(self):
296
+ self.yielded = 0
297
+ if self.next_yielded is not None:
298
+ self.yielded = self.next_yielded
299
+ self.next_yielded = None
300
+ it = super().__iter__()
301
+ for idx in itertools.islice(it, self.yielded, None):
302
+ self.yielded += 1
303
+ yield idx
304
+
305
+ def state_dict(self) -> Dict[str, Any]:
306
+ return {self._YIELDED: self.yielded}
307
+
308
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
309
+ if self._YIELDED not in state_dict:
310
+ raise ValueError("Invalid state_dict")
311
+ if state_dict[self._YIELDED] < 0:
312
+ raise ValueError("Cannot load state_dict with negative yielded value")
313
+ self.next_yielded = state_dict[self._YIELDED]
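To see how these helpers are meant to compose, here is a sketch of a bare-bones DDP entry point launched with `torchrun --nproc_per_node=N train_toy.py`. The toy model and checkpoint path are illustrative; the repository's actual training script is not part of this commit.

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    from surya.utils.distributed import (
        init_ddp,
        set_global_seed,
        save_model_singular,
        print0,
    )

    def main():
        use_gpu = torch.cuda.is_available()
        # init_ddp reads RANK / LOCAL_RANK / WORLD_SIZE, which torchrun sets.
        local_rank, rank = init_ddp(use_gpu=use_gpu)
        set_global_seed(rank)

        device = torch.device(f"cuda:{local_rank}" if use_gpu else "cpu")
        model = nn.Linear(16, 4).to(device)
        model = DDP(model, device_ids=[local_rank] if use_gpu else None)

        # ... forward/backward/optimizer steps would go here ...

        print0("Writing checkpoint from the main process only.")
        save_model_singular(model, "checkpoints/toy.pt", parallelism="ddp")

    if __name__ == "__main__":
        main()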
surya/utils/log.py ADDED
@@ -0,0 +1,110 @@
1
+ import functools
2
+ import logging
3
+ import os
4
+ import sys
5
+ from time import time
6
+ from packaging.version import Version
7
+ import wandb
8
+ from typing import Dict, Optional, Any
9
+
10
+
11
+ if Version(wandb.__version__) < Version("0.20.0"):
12
+ WANDB_USE_SYNC = True
13
+ else:
14
+ WANDB_USE_SYNC = False
15
+
16
+
17
+ def log(
18
+ run,
19
+ data: Dict[str, Any],
20
+ step: Optional[int] = None,
21
+ commit: Optional[bool] = None,
22
+ sync: Optional[bool] = None,
23
+ ) -> None:
24
+ if run is not None:
25
+ # Note: wandb changed the .log API with version 0.20.0.
26
+ # This includes: "Removed no-op sync argument from wandb.Run::log function"
27
+ # We didn't test whether sync has any function here. But since we did
28
+ # all our development with it, let's keep it here for now.
29
+ # See https://github.com/wandb/wandb/releases/tag/v0.20.0
30
+ if WANDB_USE_SYNC:
31
+ run.log(data, step, commit, sync)
32
+ else:
33
+ run.log(data, step, commit)
34
+ else:
35
+ print(data)
36
+
37
+
38
+ # See: https://github.com/microsoft/Swin-Transformer/blob/main/logger.py
39
+ # See: https://github.com/Meituan-AutoML/Twins/blob/main/logger.py
40
+ def create_logger(output_dir: str, dist_rank: int, name: str) -> logging.Logger:
41
+ # create logger
42
+ logger = logging.getLogger(name)
43
+ logger.setLevel(logging.DEBUG)
44
+ logger.propagate = False
45
+
46
+ # create formatter
47
+ fmt = "[%(asctime)s %(name)s]: %(levelname)s %(message)s"
48
+
49
+ # create console handlers
50
+ if name.endswith("main"):
51
+ console_handler = logging.StreamHandler(sys.stdout)
52
+ console_handler.setLevel(logging.INFO)
53
+ console_handler.setFormatter(
54
+ logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")
55
+ )
56
+ logger.addHandler(console_handler)
57
+
58
+ # create file handlers
59
+ file_handler = logging.FileHandler(
60
+ os.path.join(output_dir, f"{name}.log"), mode="a"
61
+ )
62
+ file_handler.setLevel(logging.DEBUG)
63
+ file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S"))
64
+ logger.addHandler(file_handler)
65
+
66
+ return logger
67
+
68
+
69
+ def log_decorator(logger, _func=None):
70
+ def log_decorator_info(func):
71
+ @functools.wraps(func)
72
+ def log_decorator_wrapper(*args, **kwargs):
73
+ """Create a list of the positional arguments passed to function.
74
+ - Using repr() for string representation for each argument. repr() is similar to str() only
75
+ difference being it prints with a pair of quotes and if we calculate a value we get more
76
+ precise value than str().
77
+ """
78
+
79
+ # py_file_caller = getframeinfo(stack()[1][0])
80
+
81
+ local_rank = os.environ.get("LOCAL_RANK", default=None)
82
+ rank = os.environ.get("LOCAL_RANK", default=None)
83
+
84
+ try:
85
+ """log return value from the function"""
86
+ start_time = time()
87
+ value = func(*args, **kwargs)
88
+ if local_rank is None or rank is None:
89
+ logger.info(
90
+ f"Function '{func.__name__}' - Execution time: {(time() - start_time):.1f} seconds."
91
+ )
92
+ else:
93
+ logger.info(
94
+ f"Function '{func.__name__}' - Execution time: {(time() - start_time):.1f} "
95
+ f"seconds on rank {os.environ['RANK']} and local_rank {os.environ['LOCAL_RANK']}."
96
+ )
97
+ except Exception as err:
98
+ logger.error(f"Exception: {err}")
99
+ raise
100
+ return value
101
+
102
+ # Return the pointer to the function
103
+ return log_decorator_wrapper
104
+
105
+ # Decorator was called with arguments, so return a decorator function that can read and return a function
106
+ if _func is None:
107
+ return log_decorator_info
108
+ # Decorator was called without arguments, so apply the decorator to the function immediately
109
+ else:
110
+ return log_decorator_info(_func)
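A short sketch of the logging helpers in use (the `logs` directory and logger name are illustrative). Note that `create_logger` only attaches a stdout handler when the logger name ends in "main", so per-rank loggers write to their own files without flooding the console.

    import os
    from surya.utils.log import create_logger, log_decorator

    os.makedirs("logs", exist_ok=True)
    logger = create_logger(output_dir="logs", dist_rank=0, name="rank0_main")

    @log_decorator(logger)
    def slow_sum(n: int) -> int:
        return sum(range(n))

    logger.info("Starting toy run.")
    print(slow_sum(1_000_000))   # the decorator logs the call's execution time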
surya/utils/misc.py ADDED
@@ -0,0 +1,90 @@
1
+ import numbers
2
+ from logging import Logger
3
+ from time import time
4
+
5
+ import numpy as np
6
+ import torch
7
+ from numpy.lib.stride_tricks import as_strided
8
+ from torch.utils.data import DataLoader
9
+
10
+
11
+ def view_as_windows(arr_in: np.ndarray, window_shape, step=1) -> np.ndarray:
12
+ """Rolling window view of the input n-dimensional array.
13
+ Windows are overlapping views of the input array, with adjacent windows
14
+ shifted by a single row or column (or an index of a higher dimension).
15
+
16
+ Ref: https://github.com/scikit-image/scikit-image/blob/5e74a4a3a5149a8a14566b81a32bb15499aa3857/skimage/util/shape.py#L97-L247
17
+ Parameters
18
+ """
19
+
20
+ # -- basic checks on arguments
21
+ if not isinstance(arr_in, np.ndarray):
22
+ raise TypeError("`arr_in` must be a numpy ndarray")
23
+
24
+ ndim = arr_in.ndim
25
+
26
+ if isinstance(window_shape, numbers.Number):
27
+ window_shape = (window_shape,) * ndim
28
+ if not (len(window_shape) == ndim):
29
+ raise ValueError("`window_shape` is incompatible with `arr_in.shape`")
30
+
31
+ if isinstance(step, numbers.Number):
32
+ if step < 1:
33
+ raise ValueError("`step` must be >= 1")
34
+ step = (step,) * ndim
35
+ if len(step) != ndim:
36
+ raise ValueError("`step` is incompatible with `arr_in.shape`")
37
+
38
+ arr_shape = np.array(arr_in.shape)
39
+ window_shape = np.array(window_shape, dtype=arr_shape.dtype)
40
+
41
+ if ((arr_shape - window_shape) < 0).any():
42
+ raise ValueError("`window_shape` is too large")
43
+
44
+ if ((window_shape - 1) < 0).any():
45
+ raise ValueError("`window_shape` is too small")
46
+
47
+ # -- build rolling window view
48
+ slices = tuple(slice(None, None, st) for st in step)
49
+ window_strides = np.array(arr_in.strides)
50
+
51
+ indexing_strides = arr_in[slices].strides
52
+
53
+ win_indices_shape = (
54
+ (np.array(arr_in.shape) - np.array(window_shape)) // np.array(step)
55
+ ) + 1
56
+
57
+ new_shape = tuple(list(win_indices_shape) + list(window_shape))
58
+ strides = tuple(list(indexing_strides) + list(window_strides))
59
+
60
+ arr_out = as_strided(arr_in, shape=new_shape, strides=strides)
61
+ return arr_out
62
+
63
+
64
+ def class_from_name(module_name: str, class_name: str) -> object:
65
+ # load the module, will raise ImportError if module cannot be loaded
66
+ m = __import__(module_name, globals(), locals(), [class_name])
67
+ # get the class, will raise AttributeError if class cannot be found
68
+ c = getattr(m, class_name)
69
+ return c
70
+
71
+
72
+ @torch.no_grad()
73
+ def throughput(data_loader: DataLoader, model: torch.nn.Module, logger: Logger):
74
+ model.eval()
75
+
76
+ for idx, (images, _) in enumerate(data_loader):
77
+ images = images.cuda(non_blocking=True)
78
+ batch_size = images.shape[0]
79
+ for i in range(50):
80
+ model(images)
81
+ torch.cuda.synchronize()
82
+ logger.info("throughput averaged with 30 times")
83
+ tic1 = time()
84
+ for i in range(30):
85
+ model(images)
86
+ torch.cuda.synchronize()
87
+ tic2 = time()
88
+ logger.info(
89
+ f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}"
90
+ )
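`view_as_windows` is the building block behind `break_batch_5d`; it produces a strided, zero-copy view of tiles. A minimal 2-D illustration:

    import numpy as np
    from surya.utils.misc import view_as_windows

    arr = np.arange(16).reshape(4, 4)
    # Non-overlapping 2x2 tiles: a 2x2 grid of 2x2 patches.
    windows = view_as_windows(arr, window_shape=(2, 2), step=(2, 2))
    print(windows.shape)    # (2, 2, 2, 2)
    print(windows[0, 1])    # top-right patch: [[2, 3], [6, 7]]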
surya/utils/schemas.py ADDED
@@ -0,0 +1,14 @@
1
+ from typing import TypedDict, Dict, Any
2
+ import torch
3
+
4
+
5
+ class TrainState(TypedDict):
6
+ dataloader: torch.utils.data.DataLoader
7
+ optimizer: Dict[str, Any]
8
+ scheduler: Dict[str, Any]
9
+ sampler: Any # Changed from torch.utils.data.sampler to Any
10
+ profiler: bool
11
+ epoch: int
12
+ iteration: int
13
+ loss: float
14
+ wandb_state: int
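`TrainState` is the TypedDict that `save_state_singular` in surya/utils/distributed.py persists. A sketch of assembling one; the concrete objects and the meaning of `wandb_state` are assumptions, not defined by this diff.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from surya.utils.schemas import TrainState

    dataset = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
    loader = DataLoader(dataset, batch_size=4)
    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    state: TrainState = {
        "dataloader": loader,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "sampler": None,           # e.g. a StatefulDistributedSampler.state_dict()
        "profiler": False,
        "epoch": 0,
        "iteration": 0,
        "loss": float("inf"),
        "wandb_state": 0,
    }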