johannesschmude's picture
Initial commit
b73936d
from typing import Dict
import numpy as np
import torch
from surya.datasets.transformations import Transformation, StandardScaler
from surya.utils.config import DataConfig
from surya.utils.misc import class_from_name, view_as_windows
def custom_collate_fn(batch):
"""
Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.
This function separately processes the data and metadata from the input batch.
- The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to incompatible data types,
the batch is returned as-is.
- The `metadata_batch` is assumed to be a dictionary, where each key corresponds to a list of values across the batch.
Each key is collated using `default_collate`. If collation fails for a particular key, the original list of values
is retained.
Example usage for accessing collated metadata:
- `collated_metadata['timestamps_input'][batch_idx][input_time]`
- `collated_metadata['timestamps_input'][batch_idx][rollout_step]`
Args:
batch (list of tuples): Each tuple contains (data, metadata), where:
- `data` is a tensor or other data structure used for training.
- `metadata` is a dictionary containing additional information.
Returns:
tuple: (collated_data, collated_metadata)
- `collated_data`: The processed batch of data.
- `collated_metadata`: The processed batch of metadata.
"""
# Unpack batch into separate lists of data and metadata
data_batch, metadata_batch = zip(*batch)
# Attempt to collate the data batch using PyTorch's default collate function
try:
collated_data = torch.utils.data.default_collate(data_batch)
except TypeError:
# If default_collate fails (e.g., due to incompatible types), return the data batch as-is
collated_data = data_batch
# Handle metadata collation
if isinstance(metadata_batch[0], dict):
collated_metadata = {}
for key in metadata_batch[0].keys():
values = [d[key] for d in metadata_batch]
try:
# Attempt to collate values under the current key
collated_metadata[key] = torch.utils.data.default_collate(values)
except TypeError:
# If collation fails, keep the values as a list
collated_metadata[key] = values
else:
# If metadata is not a dictionary, try to collate it as a whole
try:
collated_metadata = torch.utils.data.default_collate(metadata_batch)
except TypeError:
# If collation fails, return metadata as-is
collated_metadata = metadata_batch
return collated_data, collated_metadata
def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
return (raw_size - win_size) // stride + 1
def get_scalers_info(dataset) -> dict:
return {
k: (type(v).__module__, type(v).__name__, v.to_dict())
for k, v in dataset.scalers.items()
}
def build_scalers_pressure(info: dict) -> Dict[str, Transformation]:
ret_dict = {k: dict() for k in info.keys()}
for var_key, var_d in info.items():
for p_key, p_val in var_d.items():
ret_dict[var_key][p_key] = class_from_name(
p_val["base"], p_val["class"]
).from_dict(p_val)
return ret_dict
def build_scalers(info: dict) -> Dict[str, Transformation]:
ret_dict = {k: None for k in info.keys()}
for p_key, p_val in info.items():
ret_dict[p_key]: StandardScaler = class_from_name(
p_val["base"], p_val["class"]
).from_dict(p_val)
return ret_dict
def break_batch_5d(
data: list, lat_size: int, lon_size: int, time_steps: int
) -> np.ndarray:
"""
data: list of samples, each sample is [C, T, L, H, W]
"""
num_levels = data[0].shape[2]
num_vars = data[0].shape[0]
big_batch = np.stack(data, axis=0)
vw = view_as_windows(
big_batch,
[1, num_vars, time_steps, num_levels, lat_size, lon_size],
step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
).squeeze()
# To check if it is correctly reshaping
# idx = 30
# (big_batch[0, :, idx:idx+2, :, 40:80, 40:80]-vw[idx//2, 1, 1]).sum()
vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
# How to test:
# (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
# (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
# (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
# Need to move axis because Weather model is expecting [C, L, T, H, W] instead of [C, T, L, H, W]
vw = np.moveaxis(vw, 3, 2)
vw = torch.tensor(vw, dtype=torch.float32)
return vw
def break_batch_5d_aug(data: list, cfg: DataConfig, max_batch: int = 256) -> np.ndarray:
num_levels = data[0].shape[2]
num_vars = data[0].shape[0]
big_batch = np.stack(data, axis=0)
y_step, x_step, t_step = (
cfg.patch_size_lat // 2,
cfg.patch_size_lon // 2,
cfg.patch_size_time // 2,
)
y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
max_batch = min(max_batch, y_max * x_max * t_max)
batch = np.empty(
(
max_batch,
num_vars,
cfg.input_size_time,
num_levels,
cfg.input_size_lat,
cfg.input_size_lon,
),
dtype=np.float32,
)
for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
t, y, x = np.unravel_index(
i,
(
t_max,
y_max,
x_max,
),
)
batch[j] = big_batch[
:, # batch_id
:, # vars
t * t_step : t * t_step + cfg.input_size_time,
:, # levels
y * y_step : y * y_step + cfg.input_size_lat,
x * x_step : x * x_step + cfg.input_size_lon,
]
batch = np.moveaxis(batch, 3, 2)
batch = torch.tensor(batch, dtype=torch.float32)
return batch