from typing import Dict

import numpy as np
import torch

from surya.datasets.transformations import Transformation, StandardScaler
from surya.utils.config import DataConfig
from surya.utils.misc import class_from_name, view_as_windows

def custom_collate_fn(batch):
    """
    Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.

    The data and metadata in each batch are processed separately:
    - `data_batch` is collated with PyTorch's `default_collate`. If collation fails due to
      incompatible data types, the batch is returned as-is.
    - `metadata_batch` is assumed to be a dictionary in which each key maps to a list of
      values across the batch. Each key is collated with `default_collate`; if collation
      fails for a particular key, the original list of values is retained.

    Example usage for accessing collated metadata:
    - `collated_metadata['timestamps_input'][batch_idx][input_time]`
    - `collated_metadata['timestamps_input'][batch_idx][rollout_step]`

    Args:
        batch (list of tuples): Each tuple contains (data, metadata), where:
            - `data` is a tensor or other data structure used for training.
            - `metadata` is a dictionary containing additional information.

    Returns:
        tuple: (collated_data, collated_metadata)
            - `collated_data`: The processed batch of data.
            - `collated_metadata`: The processed batch of metadata.
    """
    # Unpack the batch into separate lists of data and metadata
    data_batch, metadata_batch = zip(*batch)

    # Attempt to collate the data batch using PyTorch's default collate function
    try:
        collated_data = torch.utils.data.default_collate(data_batch)
    except TypeError:
        # If default_collate fails (e.g., due to incompatible types), return the data batch as-is
        collated_data = data_batch

    # Handle metadata collation
    if isinstance(metadata_batch[0], dict):
        collated_metadata = {}
        for key in metadata_batch[0].keys():
            values = [d[key] for d in metadata_batch]
            try:
                # Attempt to collate the values under the current key
                collated_metadata[key] = torch.utils.data.default_collate(values)
            except TypeError:
                # If collation fails, keep the values as a list
                collated_metadata[key] = values
    else:
        # If metadata is not a dictionary, try to collate it as a whole
        try:
            collated_metadata = torch.utils.data.default_collate(metadata_batch)
        except TypeError:
            # If collation fails, return metadata as-is
            collated_metadata = metadata_batch

    return collated_data, collated_metadata
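
# A minimal usage sketch for `custom_collate_fn`, assuming a dataset whose
# __getitem__ yields (data, metadata) tuples; the batch size is illustrative:
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate_fn)
#     data, metadata = next(iter(loader))
#     # Tensor-like metadata values are stacked along the batch dimension;
#     # values `default_collate` cannot handle stay a list of length 4.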

def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
    """Number of windows of size `win_size` that fit into `raw_size` at the given stride."""
    return (raw_size - win_size) // stride + 1
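
# A quick sanity check of the window-count formula:
#     calc_num_windows(10, 4, 2) == 4   # windows start at offsets 0, 2, 4, 6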

def get_scalers_info(dataset) -> dict:
    """Serialize a dataset's scalers as (module, class name, state dict) triples."""
    return {
        k: (type(v).__module__, type(v).__name__, v.to_dict())
        for k, v in dataset.scalers.items()
    }
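
# Shape of the serialized output, assuming a dataset with a single
# StandardScaler (the variable name "t2m" is illustrative):
#     {"t2m": ("surya.datasets.transformations", "StandardScaler", {...})}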

def build_scalers_pressure(info: dict) -> Dict[str, Dict[str, Transformation]]:
    """Rebuild per-variable, per-pressure-level scalers from their serialized form."""
    ret_dict: Dict[str, Dict[str, Transformation]] = {k: dict() for k in info.keys()}
    for var_key, var_d in info.items():
        for p_key, p_val in var_d.items():
            # Resolve the class named in the serialized entry and restore its state.
            ret_dict[var_key][p_key] = class_from_name(
                p_val["base"], p_val["class"]
            ).from_dict(p_val)
    return ret_dict

def build_scalers(info: dict) -> Dict[str, Transformation]:
    """Rebuild flat (per-variable) scalers from their serialized form."""
    ret_dict: Dict[str, Transformation] = {}
    for p_key, p_val in info.items():
        ret_dict[p_key] = class_from_name(
            p_val["base"], p_val["class"]
        ).from_dict(p_val)
    return ret_dict
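
# A minimal sketch of the serialized layout `build_scalers` consumes. The
# variable name "t2m" and the "mean"/"std" fields are assumptions for
# illustration; only "base" and "class" are read here, and the whole entry
# is passed through to the scaler's `from_dict`:
#
#     info = {
#         "t2m": {
#             "base": "surya.datasets.transformations",
#             "class": "StandardScaler",
#             "mean": 285.2,
#             "std": 12.7,
#         },
#     }
#     scalers = build_scalers(info)   # {"t2m": StandardScaler(...)}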

def break_batch_5d(
    data: list, lat_size: int, lon_size: int, time_steps: int
) -> torch.Tensor:
    """
    Tile a batch of samples into non-overlapping windows.

    data: list of samples, each sample is [C, T, L, H, W]
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    big_batch = np.stack(data, axis=0)
    # Non-overlapping windows: the step equals the window size along every axis.
    vw = view_as_windows(
        big_batch,
        [1, num_vars, time_steps, num_levels, lat_size, lon_size],
        step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
    ).squeeze()
    # To check that the windowing is correct:
    # idx = 30
    # (big_batch[0, :, idx:idx+2, :, 40:80, 40:80] - vw[idx//2, 1, 1]).sum()
    vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
    # How to test:
    # (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
    # (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
    # (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
    # Move the level axis forward because the weather model expects
    # [C, L, T, H, W] rather than [C, T, L, H, W].
    vw = np.moveaxis(vw, 3, 2)
    vw = torch.tensor(vw, dtype=torch.float32)
    return vw
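
# A minimal sketch of calling `break_batch_5d`. All shapes are assumptions
# for illustration: 2 samples, 3 variables, 4 time steps, 5 levels, and an
# 80x80 grid split into 40x40 tiles of 2 time steps each:
#
#     samples = [np.random.rand(3, 4, 5, 80, 80) for _ in range(2)]
#     windows = break_batch_5d(samples, lat_size=40, lon_size=40, time_steps=2)
#     # windows.shape == (16, 3, 5, 2, 40, 40), i.e. [N, C, L, T, H, W]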

def break_batch_5d_aug(
    data: list, cfg: DataConfig, max_batch: int = 256
) -> torch.Tensor:
    """
    Sample up to `max_batch` windows from `data` at random half-patch offsets.

    data: list of samples, each sample is [C, T, L, H, W]
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    # Note: the assignment inside the loop below assumes `data` holds a single
    # sample, so that the leading batch axis of `big_batch` broadcasts away.
    big_batch = np.stack(data, axis=0)
    # Candidate windows overlap by half a patch along time, latitude, and longitude.
    y_step, x_step, t_step = (
        cfg.patch_size_lat // 2,
        cfg.patch_size_lon // 2,
        cfg.patch_size_time // 2,
    )
    y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
    x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
    t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
    max_batch = min(max_batch, y_max * x_max * t_max)
    batch = np.empty(
        (
            max_batch,
            num_vars,
            cfg.input_size_time,
            num_levels,
            cfg.input_size_lat,
            cfg.input_size_lon,
        ),
        dtype=np.float32,
    )
    # Draw `max_batch` distinct window indices uniformly at random from all
    # windows, then decode each flat index into (time, lat, lon) coordinates.
    for j, i in enumerate(np.random.permutation(t_max * y_max * x_max)[:max_batch]):
        t, y, x = np.unravel_index(i, (t_max, y_max, x_max))
        batch[j] = big_batch[
            :,  # batch_id
            :,  # vars
            t * t_step : t * t_step + cfg.input_size_time,
            :,  # levels
            y * y_step : y * y_step + cfg.input_size_lat,
            x * x_step : x * x_step + cfg.input_size_lon,
        ]
    # The weather model expects [C, L, T, H, W], so swap the level and time axes.
    batch = np.moveaxis(batch, 3, 2)
    batch = torch.tensor(batch, dtype=torch.float32)
    return batch
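
# A minimal sketch of calling `break_batch_5d_aug`. The DataConfig fields
# mirror the attributes used above; all concrete values are assumptions for
# illustration:
#
#     cfg = DataConfig(...)   # with patch_size_lat=8, patch_size_lon=8,
#                             # patch_size_time=2, input_size_lat=40,
#                             # input_size_lon=40, input_size_time=2
#     sample = np.random.rand(3, 8, 5, 80, 80)   # [C, T, L, H, W]
#     batch = break_batch_5d_aug([sample], cfg, max_batch=16)
#     # batch.shape == (16, 3, 5, 2, 40, 40), i.e. [N, C, L, T, H, W]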