from typing import Dict

import numpy as np
import torch

from surya.datasets.transformations import Transformation, StandardScaler
from surya.utils.config import DataConfig
from surya.utils.misc import class_from_name, view_as_windows

def custom_collate_fn(batch):
    """
    Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.

    The data and the metadata of each sample are processed separately:

    - The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to
      incompatible data types, the batch is returned as-is.
    - The `metadata_batch` is assumed to be a dictionary in which each key maps to a list of values
      across the batch. Each key is collated using `default_collate`; if collation fails for a
      particular key, the original list of values is retained.

    Example usage for accessing collated metadata:
    - `collated_metadata['timestamps_input'][batch_idx][input_time]`
    - `collated_metadata['timestamps_input'][batch_idx][rollout_step]`

    Args:
        batch (list of tuples): Each tuple contains (data, metadata), where:
            - `data` is a tensor or other data structure used for training.
            - `metadata` is a dictionary containing additional information.

    Returns:
        tuple: (collated_data, collated_metadata)
            - `collated_data`: The processed batch of data.
            - `collated_metadata`: The processed batch of metadata.
    """
    data_batch, metadata_batch = zip(*batch)

    # Collate the data part of the batch; fall back to the raw tuple of samples
    # if default_collate cannot handle the data type.
    try:
        collated_data = torch.utils.data.default_collate(data_batch)
    except TypeError:
        collated_data = data_batch

    # Collate the metadata key by key, keeping the original list of values for
    # any key that default_collate cannot handle.
    if isinstance(metadata_batch[0], dict):
        collated_metadata = {}
        for key in metadata_batch[0].keys():
            values = [d[key] for d in metadata_batch]
            try:
                collated_metadata[key] = torch.utils.data.default_collate(values)
            except TypeError:
                collated_metadata[key] = values
    else:
        # Metadata is not a dict; collate it directly if possible.
        try:
            collated_metadata = torch.utils.data.default_collate(metadata_batch)
        except TypeError:
            collated_metadata = metadata_batch

    return collated_data, collated_metadata
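
# Illustrative usage sketch (not part of the original module): assuming a Dataset
# whose __getitem__ returns a (data, metadata) tuple, the collate function is
# passed to a DataLoader via the `collate_fn` argument (the dataset name below is
# hypothetical):
#
#     loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=8, collate_fn=custom_collate_fn
#     )
#     collated_data, collated_metadata = next(iter(loader))
#     # Keys whose values default_collate could not stack remain plain lists.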

def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
    """Number of windows of size `win_size` that fit into `raw_size` when sliding with `stride`."""
    return (raw_size - win_size) // stride + 1
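
# Worked example: for an axis of length 10, a window of length 4 and a stride of 2,
# calc_num_windows(10, 4, 2) == (10 - 4) // 2 + 1 == 4 windows
# (starting at offsets 0, 2, 4 and 6).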

def get_scalers_info(dataset) -> dict:
    """Serialize the dataset's scalers as {variable: (module, class name, state dict)}."""
    return {
        k: (type(v).__module__, type(v).__name__, v.to_dict())
        for k, v in dataset.scalers.items()
    }
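
# Illustrative result (hypothetical variable name "aia_0094"; the exact contents of
# the state dict depend on the Transformation subclass's to_dict()):
#
#     {"aia_0094": ("surya.datasets.transformations", "StandardScaler", {...})}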

def build_scalers_pressure(info: dict) -> Dict[str, Dict[str, Transformation]]:
    """Rebuild per-variable, per-pressure-level scalers from their serialized dictionary form."""
    ret_dict = {k: dict() for k in info.keys()}
    for var_key, var_d in info.items():
        for p_key, p_val in var_d.items():
            ret_dict[var_key][p_key] = class_from_name(
                p_val["base"], p_val["class"]
            ).from_dict(p_val)
    return ret_dict
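
# Illustrative sketch of the nested `info` layout this expects (hypothetical variable
# and pressure-level keys; "base" and "class" locate the Transformation subclass, and
# the remaining entries are whatever its from_dict() needs):
#
#     info = {
#         "temperature": {
#             "500": {"base": "surya.datasets.transformations", "class": "StandardScaler", ...},
#         },
#     }
#     scalers = build_scalers_pressure(info)
#     # scalers["temperature"]["500"] is a StandardScaler instance.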

def build_scalers(info: dict) -> Dict[str, Transformation]:
    """Rebuild one scaler per variable from its serialized dictionary form."""
    ret_dict = {k: None for k in info.keys()}
    for p_key, p_val in info.items():
        ret_dict[p_key] = class_from_name(
            p_val["base"], p_val["class"]
        ).from_dict(p_val)
    return ret_dict
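
# Illustrative sketch for the flat (per-variable) case (hypothetical key "aia_0094"):
#
#     info = {
#         "aia_0094": {"base": "surya.datasets.transformations", "class": "StandardScaler", ...},
#     }
#     scalers = build_scalers(info)
#     # scalers["aia_0094"] is a StandardScaler rebuilt from its serialized state.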

def break_batch_5d(
    data: list, lat_size: int, lon_size: int, time_steps: int
) -> torch.Tensor:
    """
    Split a list of samples, each of shape [C, T, L, H, W], into non-overlapping windows of shape
    [C, time_steps, L, lat_size, lon_size] and return them as a single tensor of shape
    [N, C, L, time_steps, lat_size, lon_size].
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    big_batch = np.stack(data, axis=0)

    # Slide a non-overlapping window (step == window size) over the stacked batch
    # of shape [B, C, T, L, H, W] and flatten the window grid into a single batch
    # dimension.
    vw = view_as_windows(
        big_batch,
        [1, num_vars, time_steps, num_levels, lat_size, lon_size],
        step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
    ).squeeze()
    vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)

    # Swap the time and level axes: [N, C, T, L, lat, lon] -> [N, C, L, T, lat, lon].
    vw = np.moveaxis(vw, 3, 2)
    vw = torch.tensor(vw, dtype=torch.float32)
    return vw
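
# Shape walkthrough (hypothetical sizes): with 2 samples of shape
# [C=3, T=4, L=5, H=64, W=128] and time_steps=2, lat_size=32, lon_size=64,
# break_batch_5d tiles each sample into 2 * 2 * 2 = 8 windows, i.e. 16 windows
# total, returned as a tensor of shape [16, 3, 5, 2, 32, 64]
# (batch, vars, levels, time, lat, lon).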

def break_batch_5d_aug(
    data: list, cfg: DataConfig, max_batch: int = 256
) -> torch.Tensor:
    """
    Split a list of samples, each of shape [C, T, L, H, W], into overlapping windows of shape
    [C, cfg.input_size_time, L, cfg.input_size_lat, cfg.input_size_lon] using a half-patch stride
    in time, latitude and longitude, and return up to `max_batch` of them in random order as a
    tensor of shape [N, C, L, cfg.input_size_time, cfg.input_size_lat, cfg.input_size_lon].
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    big_batch = np.stack(data, axis=0)

    # Half-patch strides along latitude, longitude and time produce overlapping windows.
    y_step, x_step, t_step = (
        cfg.patch_size_lat // 2,
        cfg.patch_size_lon // 2,
        cfg.patch_size_time // 2,
    )
    y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
    x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
    t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
    max_batch = min(max_batch, y_max * x_max * t_max)

    batch = np.empty(
        (
            max_batch,
            num_vars,
            cfg.input_size_time,
            num_levels,
            cfg.input_size_lat,
            cfg.input_size_lon,
        ),
        dtype=np.float32,
    )
    # Visit the first `max_batch` window indices in random order and copy the
    # corresponding window out of the stacked batch.
    for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
        t, y, x = np.unravel_index(i, (t_max, y_max, x_max))
        batch[j] = big_batch[
            :,
            :,
            t * t_step : t * t_step + cfg.input_size_time,
            :,
            y * y_step : y * y_step + cfg.input_size_lat,
            x * x_step : x * x_step + cfg.input_size_lon,
        ]

    # Swap the time and level axes: [N, C, T, L, lat, lon] -> [N, C, L, T, lat, lon].
    batch = np.moveaxis(batch, 3, 2)
    batch = torch.tensor(batch, dtype=torch.float32)
    return batch
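
# Illustrative sketch (hypothetical config values): with a single sample of shape
# [C=3, T=8, L=5, H=128, W=256] and a DataConfig where input_size_time=2,
# input_size_lat=64, input_size_lon=64, patch_size_time=2, patch_size_lat=16 and
# patch_size_lon=16, the half-patch strides are (t=1, y=8, x=8), giving
# t_max=7, y_max=9, x_max=25 candidate windows; break_batch_5d_aug(data, cfg)
# then returns a tensor of shape [256, 3, 5, 2, 64, 64].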