johannesschmude's picture
Initial commit
b73936d
raw
history blame
6.3 kB
from typing import Dict
import numpy as np
import torch
from surya.datasets.transformations import Transformation, StandardScaler
from surya.utils.config import DataConfig
from surya.utils.misc import class_from_name, view_as_windows
def custom_collate_fn(batch):
"""
Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.
This function separately processes the data and metadata from the input batch.
- The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to incompatible data types,
the batch is returned as-is.
- The `metadata_batch` is assumed to be a dictionary, where each key corresponds to a list of values across the batch.
Each key is collated using `default_collate`. If collation fails for a particular key, the original list of values
is retained.
Example usage for accessing collated metadata:
- `collated_metadata['timestamps_input'][batch_idx][input_time]`
- `collated_metadata['timestamps_input'][batch_idx][rollout_step]`
Args:
batch (list of tuples): Each tuple contains (data, metadata), where:
- `data` is a tensor or other data structure used for training.
- `metadata` is a dictionary containing additional information.
Returns:
tuple: (collated_data, collated_metadata)
- `collated_data`: The processed batch of data.
- `collated_metadata`: The processed batch of metadata.
"""
# Unpack batch into separate lists of data and metadata
data_batch, metadata_batch = zip(*batch)
# Attempt to collate the data batch using PyTorch's default collate function
try:
collated_data = torch.utils.data.default_collate(data_batch)
except TypeError:
# If default_collate fails (e.g., due to incompatible types), return the data batch as-is
collated_data = data_batch
# Handle metadata collation
if isinstance(metadata_batch[0], dict):
collated_metadata = {}
for key in metadata_batch[0].keys():
values = [d[key] for d in metadata_batch]
try:
# Attempt to collate values under the current key
collated_metadata[key] = torch.utils.data.default_collate(values)
except TypeError:
# If collation fails, keep the values as a list
collated_metadata[key] = values
else:
# If metadata is not a dictionary, try to collate it as a whole
try:
collated_metadata = torch.utils.data.default_collate(metadata_batch)
except TypeError:
# If collation fails, return metadata as-is
collated_metadata = metadata_batch
return collated_data, collated_metadata
def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
return (raw_size - win_size) // stride + 1
def get_scalers_info(dataset) -> dict:
return {
k: (type(v).__module__, type(v).__name__, v.to_dict())
for k, v in dataset.scalers.items()
}
def build_scalers_pressure(info: dict) -> Dict[str, Transformation]:
ret_dict = {k: dict() for k in info.keys()}
for var_key, var_d in info.items():
for p_key, p_val in var_d.items():
ret_dict[var_key][p_key] = class_from_name(
p_val["base"], p_val["class"]
).from_dict(p_val)
return ret_dict
def build_scalers(info: dict) -> Dict[str, Transformation]:
ret_dict = {k: None for k in info.keys()}
for p_key, p_val in info.items():
ret_dict[p_key]: StandardScaler = class_from_name(
p_val["base"], p_val["class"]
).from_dict(p_val)
return ret_dict
def break_batch_5d(
data: list, lat_size: int, lon_size: int, time_steps: int
) -> np.ndarray:
"""
data: list of samples, each sample is [C, T, L, H, W]
"""
num_levels = data[0].shape[2]
num_vars = data[0].shape[0]
big_batch = np.stack(data, axis=0)
vw = view_as_windows(
big_batch,
[1, num_vars, time_steps, num_levels, lat_size, lon_size],
step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
).squeeze()
# To check if it is correctly reshaping
# idx = 30
# (big_batch[0, :, idx:idx+2, :, 40:80, 40:80]-vw[idx//2, 1, 1]).sum()
vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
# How to test:
# (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
# (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
# (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
# Need to move axis because Weather model is expecting [C, L, T, H, W] instead of [C, T, L, H, W]
vw = np.moveaxis(vw, 3, 2)
vw = torch.tensor(vw, dtype=torch.float32)
return vw
def break_batch_5d_aug(data: list, cfg: DataConfig, max_batch: int = 256) -> np.ndarray:
num_levels = data[0].shape[2]
num_vars = data[0].shape[0]
big_batch = np.stack(data, axis=0)
y_step, x_step, t_step = (
cfg.patch_size_lat // 2,
cfg.patch_size_lon // 2,
cfg.patch_size_time // 2,
)
y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
max_batch = min(max_batch, y_max * x_max * t_max)
batch = np.empty(
(
max_batch,
num_vars,
cfg.input_size_time,
num_levels,
cfg.input_size_lat,
cfg.input_size_lon,
),
dtype=np.float32,
)
for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
t, y, x = np.unravel_index(
i,
(
t_max,
y_max,
x_max,
),
)
batch[j] = big_batch[
:, # batch_id
:, # vars
t * t_step : t * t_step + cfg.input_size_time,
:, # levels
y * y_step : y * y_step + cfg.input_size_lat,
x * x_step : x * x_step + cfg.input_size_lon,
]
batch = np.moveaxis(batch, 3, 2)
batch = torch.tensor(batch, dtype=torch.float32)
return batch