from typing import Dict

import numpy as np
import torch

from surya.datasets.transformations import Transformation, StandardScaler
from surya.utils.config import DataConfig
from surya.utils.misc import class_from_name, view_as_windows

def custom_collate_fn(batch):
    """
    Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.

    This function separately processes the data and metadata from the input batch.

    - The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to
      incompatible data types, the batch is returned as-is.
    - The `metadata_batch` is assumed to be a dictionary, where each key corresponds to a list of
      values across the batch. Each key is collated using `default_collate`. If collation fails
      for a particular key, the original list of values is retained.

    Example usage for accessing collated metadata:
    - `collated_metadata['timestamps_input'][batch_idx][input_time]`
    - `collated_metadata['timestamps_input'][batch_idx][rollout_step]`

    Args:
        batch (list of tuples): Each tuple contains (data, metadata), where:
            - `data` is a tensor or other data structure used for training.
            - `metadata` is a dictionary containing additional information.

    Returns:
        tuple: (collated_data, collated_metadata)
            - `collated_data`: The processed batch of data.
            - `collated_metadata`: The processed batch of metadata.
    """
    # Unpack batch into separate lists of data and metadata
    data_batch, metadata_batch = zip(*batch)

    # Attempt to collate the data batch using PyTorch's default collate function
    try:
        collated_data = torch.utils.data.default_collate(data_batch)
    except TypeError:
        # If default_collate fails (e.g., due to incompatible types), return the data batch as-is
        collated_data = data_batch

    # Handle metadata collation
    if isinstance(metadata_batch[0], dict):
        collated_metadata = {}
        for key in metadata_batch[0].keys():
            values = [d[key] for d in metadata_batch]
            try:
                # Attempt to collate values under the current key
                collated_metadata[key] = torch.utils.data.default_collate(values)
            except TypeError:
                # If collation fails, keep the values as a list
                collated_metadata[key] = values
    else:
        # If metadata is not a dictionary, try to collate it as a whole
        try:
            collated_metadata = torch.utils.data.default_collate(metadata_batch)
        except TypeError:
            # If collation fails, return metadata as-is
            collated_metadata = metadata_batch

    return collated_data, collated_metadata
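
# A minimal usage sketch (illustrative only): `my_dataset` is assumed to yield
# (data, metadata) tuples as described in the docstring above.
#
# from torch.utils.data import DataLoader
#
# loader = DataLoader(my_dataset, batch_size=4, collate_fn=custom_collate_fn)
# data, metadata = next(iter(loader))
# # Tensor fields are stacked along a new batch dimension; metadata values that
# # default_collate cannot handle (e.g. lists of timestamp strings) stay as lists.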

def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
    """Number of windows of size `win_size` that fit into an axis of length `raw_size` with the given stride."""
    return (raw_size - win_size) // stride + 1
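
# Worked example (illustrative): an axis of length 120 with windows of size 40
# and stride 20 fits (120 - 40) // 20 + 1 = 5 windows.
# calc_num_windows(120, 40, 20)  # -> 5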

def get_scalers_info(dataset) -> dict:
    """Serialize the dataset's scalers as (module, class name, state dict) tuples keyed by variable."""
    return {
        k: (type(v).__module__, type(v).__name__, v.to_dict())
        for k, v in dataset.scalers.items()
    }

def build_scalers_pressure(info: dict) -> Dict[str, Dict[str, Transformation]]:
    """Rebuild per-variable, per-pressure-level scalers from their serialized form."""
    ret_dict = {k: dict() for k in info.keys()}
    for var_key, var_d in info.items():
        for p_key, p_val in var_d.items():
            ret_dict[var_key][p_key] = class_from_name(
                p_val["base"], p_val["class"]
            ).from_dict(p_val)
    return ret_dict

def build_scalers(info: dict) -> Dict[str, Transformation]:
    """Rebuild per-variable scalers from their serialized form."""
    ret_dict = {k: None for k in info.keys()}
    for p_key, p_val in info.items():
        ret_dict[p_key]: StandardScaler = class_from_name(
            p_val["base"], p_val["class"]
        ).from_dict(p_val)
    return ret_dict
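
# Sketch of the `info` layout consumed by build_scalers (keys shown are
# illustrative; the exact state fields depend on the Transformation subclass
# and what its from_dict() expects):
#
# info = {
#     "temperature": {
#         "base": "surya.datasets.transformations",  # module path for class_from_name
#         "class": "StandardScaler",                  # class name for class_from_name
#         # ... remaining entries are whatever StandardScaler.from_dict() reads
#     },
# }
# scalers = build_scalers(info)  # -> {"temperature": <StandardScaler instance>}
#
# build_scalers_pressure expects one extra nesting level: info[variable][pressure_level]
# holds a dict of the same form.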

def break_batch_5d(
    data: list, lat_size: int, lon_size: int, time_steps: int
) -> torch.Tensor:
    """
    data: list of samples, each sample is [C, T, L, H, W]
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    big_batch = np.stack(data, axis=0)
    vw = view_as_windows(
        big_batch,
        [1, num_vars, time_steps, num_levels, lat_size, lon_size],
        step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
    ).squeeze()
    # To check that the windowing reshapes correctly:
    # idx = 30
    # (big_batch[0, :, idx:idx+2, :, 40:80, 40:80] - vw[idx//2, 1, 1]).sum()
    vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
    # How to test:
    # (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
    # (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
    # (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
    # Move the axis because the weather model expects [C, L, T, H, W] instead of [C, T, L, H, W]
    vw = np.moveaxis(vw, 3, 2)
    vw = torch.tensor(vw, dtype=torch.float32)
    return vw
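
# Illustrative shapes (assuming view_as_windows tiles the array the same way as
# skimage.util.view_as_windows): two samples of shape [C=2, T=4, L=3, H=80, W=80]
# tiled into 2-step temporal and 40x40 spatial windows yield
# 2 * (4 // 2) * (80 // 40) * (80 // 40) = 16 patches.
#
# samples = [np.random.rand(2, 4, 3, 80, 80).astype(np.float32) for _ in range(2)]
# patches = break_batch_5d(samples, lat_size=40, lon_size=40, time_steps=2)
# patches.shape  # -> torch.Size([16, 2, 3, 2, 40, 40]), i.e. [N, C, L, T, H, W]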

def break_batch_5d_aug(
    data: list, cfg: DataConfig, max_batch: int = 256
) -> torch.Tensor:
    """
    Randomly crop up to `max_batch` windows of shape
    [num_vars, input_size_time, num_levels, input_size_lat, input_size_lon]
    from the stacked samples, using half-patch strides along time, lat, and lon.
    """
    num_levels = data[0].shape[2]
    num_vars = data[0].shape[0]
    big_batch = np.stack(data, axis=0)
    y_step, x_step, t_step = (
        cfg.patch_size_lat // 2,
        cfg.patch_size_lon // 2,
        cfg.patch_size_time // 2,
    )
    y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
    x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
    t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
    max_batch = min(max_batch, y_max * x_max * t_max)
    batch = np.empty(
        (
            max_batch,
            num_vars,
            cfg.input_size_time,
            num_levels,
            cfg.input_size_lat,
            cfg.input_size_lon,
        ),
        dtype=np.float32,
    )
    for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
        t, y, x = np.unravel_index(i, (t_max, y_max, x_max))
        batch[j] = big_batch[
            :,  # batch_id
            :,  # vars
            t * t_step : t * t_step + cfg.input_size_time,
            :,  # levels
            y * y_step : y * y_step + cfg.input_size_lat,
            x * x_step : x * x_step + cfg.input_size_lon,
        ]
    # Move the axis so the output is [C, L, T, H, W], as expected by the model
    batch = np.moveaxis(batch, 3, 2)
    batch = torch.tensor(batch, dtype=torch.float32)
    return batch
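
# Illustrative usage (how DataConfig is constructed is assumed; only the fields
# read by this function matter here): with input_size_time=2, input_size_lat=40,
# input_size_lon=40 and patch sizes 2/40/40, a single sample of shape
# [C=2, T=8, L=3, H=160, W=160] gives t_max=7, y_max=7, x_max=7 candidate windows,
# so min(256, 343) = 256 crops of shape [C, L, T, H, W] = [2, 3, 2, 40, 40].
#
# samples = [np.random.rand(2, 8, 3, 160, 160).astype(np.float32)]
# batch = break_batch_5d_aug(samples, cfg)  # cfg: DataConfig as above
# batch.shape  # -> torch.Size([256, 2, 3, 2, 40, 40])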