import os
import random
import sys
from datetime import datetime
from functools import cache
from logging import Logger

import numpy as np
import pandas as pd
import skimage.measure
import torch
import xarray as xr
from numba import njit, prange
from torch.utils.data import Dataset

# Importing hdf5plugin registers the HDF5 compression filters needed to
# decode the compressed NetCDF files read below.
import hdf5plugin

from surya.utils.distributed import get_rank
from surya.utils.log import create_logger


@njit(parallel=True)
def fast_transform(data, means, stds, sl_scale_factors, epsilons):
    """
    Implements the signum log transform using numba for speed.

    Notes:
        - This must reside outside the class definition from which it is called.
        - We used this function during pretraining for faster data loading. On
          some GPU clusters, however, it can hang the system when data loading
          happens outside the GPU process. See `transform` below for a
          non-numba version.

    Args:
        data: Numpy array of shape C, H, W.
        means: Numpy array of shape C. Mean per channel.
        stds: Numpy array of shape C. Standard deviation per channel.
        sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
        epsilons: Numpy array of shape C. Constant to avoid zero division.

    Returns:
        Numpy array of shape C, H, W.
    """
    C, H, W = data.shape
    out = np.empty((C, H, W), dtype=np.float32)
    for c in prange(C):
        mean = means[c]
        std = stds[c]
        eps = epsilons[c]
        sl_scale_factor = sl_scale_factors[c]
        for i in range(H):
            for j in range(W):
                # Signum-log: sign(x) * log(1 + |x|), then standardize.
                val = data[c, i, j] * sl_scale_factor
                if val >= 0:
                    val = np.log1p(val)
                else:
                    val = -np.log1p(-val)
                out[c, i, j] = (val - mean) / (std + eps)
    return out


def transform(
    data: np.ndarray,
    means: np.ndarray,
    stds: np.ndarray,
    sl_scale_factors: np.ndarray,
    epsilons: np.ndarray,
) -> np.ndarray:
    """
    Implements the signum log transform. Drop-in replacement for the
    `fast_transform` function above.

    Args:
        data: Numpy array of shape C, H, W.
        means: Numpy array of shape C. Mean per channel.
        stds: Numpy array of shape C. Standard deviation per channel.
        sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
        epsilons: Numpy array of shape C. Constant to avoid zero division.

    Returns:
        Numpy array of shape C, H, W.
    """
    # Reshape per-channel statistics to C, 1, 1 so they broadcast over H and W.
    means = means.reshape(*means.shape, 1, 1)
    stds = stds.reshape(*stds.shape, 1, 1)
    sl_scale_factors = sl_scale_factors.reshape(*sl_scale_factors.shape, 1, 1)
    epsilons = epsilons.reshape(*epsilons.shape, 1, 1)

    data = data * sl_scale_factors
    data = np.sign(data) * np.log1p(np.abs(data))
    data = (data - means) / (stds + epsilons)

    return data
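

# Minimal sketch (not part of the pipeline): `transform` and `fast_transform`
# should agree on arbitrary inputs. The per-channel statistics below are made
# up purely for demonstration.
def _example_transform_agreement() -> None:
    rng = np.random.default_rng(0)
    data = rng.normal(size=(3, 4, 4)).astype(np.float32)
    means = np.zeros(3, dtype=np.float32)
    stds = np.ones(3, dtype=np.float32)
    sl_scale_factors = np.full(3, 2.0, dtype=np.float32)
    epsilons = np.full(3, 1e-8, dtype=np.float32)

    slow = transform(data, means, stds, sl_scale_factors, epsilons)
    fast = fast_transform(data, means, stds, sl_scale_factors, epsilons)
    assert np.allclose(slow, fast, atol=1e-6)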


@njit(parallel=True)
def inverse_fast_transform(data, means, stds, sl_scale_factors, epsilons):
    """
    Implements the inverse signum log transform using numba for speed.

    Args:
        data: Numpy array of shape C, H, W.
        means: Numpy array of shape C. Mean per channel.
        stds: Numpy array of shape C. Standard deviation per channel.
        sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
        epsilons: Numpy array of shape C. Constant to avoid zero division.

    Returns:
        Numpy array of shape C, H, W.
    """
    C, H, W = data.shape
    out = np.empty((C, H, W), dtype=np.float32)

    for c in prange(C):
        mean = means[c]
        std = stds[c]
        eps = epsilons[c]
        sl_scale_factor = sl_scale_factors[c]

        for i in range(H):
            for j in range(W):
                # Undo the standardization, then the signum-log, then the
                # channel scaling.
                val = data[c, i, j] * (std + eps) + mean
                if val >= 0:
                    val = np.expm1(val)
                else:
                    val = -np.expm1(-val)
                out[c, i, j] = val / sl_scale_factor

    return out


def inverse_transform_single_channel(data, mean, std, sl_scale_factor, epsilon):
    """
    Implements the inverse signum log transform for a single channel.

    Args:
        data: Numpy array of shape H, W.
        mean: Mean of the channel.
        std: Standard deviation of the channel.
        sl_scale_factor: Signum-log scale factor of the channel.
        epsilon: Constant to avoid zero division.

    Returns:
        Numpy array of shape H, W.
    """
    data = data * (std + epsilon) + mean
    data = np.sign(data) * np.expm1(np.abs(data))
    data = data / sl_scale_factor

    return data
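

# Minimal sketch (not part of the pipeline): the inverse transforms undo the
# forward transform. The channel statistics are arbitrary demonstration values.
def _example_inverse_roundtrip() -> None:
    rng = np.random.default_rng(1)
    data = rng.normal(size=(2, 4, 4)).astype(np.float32)
    means = np.array([0.1, -0.2], dtype=np.float32)
    stds = np.array([1.5, 0.7], dtype=np.float32)
    sl_scale_factors = np.array([2.0, 0.5], dtype=np.float32)
    epsilons = np.full(2, 1e-8, dtype=np.float32)

    encoded = transform(data, means, stds, sl_scale_factors, epsilons)
    decoded = inverse_fast_transform(
        encoded.astype(np.float32), means, stds, sl_scale_factors, epsilons
    )
    assert np.allclose(decoded, data, atol=1e-4)

    # The single-channel variant inverts one channel at a time.
    decoded_c0 = inverse_transform_single_channel(
        encoded[0], means[0], stds[0], sl_scale_factors[0], epsilons[0]
    )
    assert np.allclose(decoded_c0, data[0], atol=1e-4)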


class RandomChannelMaskerTransform:
    def __init__(
        self, num_channels, num_mask_aia_channels, phase, drop_hmi_probability
    ):
        """
        Initialize the RandomChannelMaskerTransform class as a transform.

        Args:
            num_channels: Total number of channels in the input (first
                dimension of the tensor).
            num_mask_aia_channels: Number of channels to randomly mask.
            phase: Unused. Kept for interface compatibility.
            drop_hmi_probability: Probability of zeroing out the HMI (last)
                channel.
        """
        self.num_channels = num_channels
        self.num_mask_aia_channels = num_mask_aia_channels
        self.drop_hmi_probability = drop_hmi_probability

    def __call__(self, input_tensor):
        C, T, H, W = input_tensor.shape

        # Choose channels to mask and build a C, 1, 1, 1 mask that broadcasts
        # over time and space.
        channels_to_mask = random.sample(range(C), self.num_mask_aia_channels)
        mask = torch.ones((C, 1, 1, 1))
        mask[channels_to_mask, ...] = 0

        masked_tensor = input_tensor * mask

        # Optionally zero out the HMI channel, assumed to be last.
        if self.drop_hmi_probability > random.random():
            masked_tensor[-1, ...] = 0

        return masked_tensor
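

# Minimal usage sketch (not part of the pipeline): mask two of eight channels
# and drop the last (HMI) channel with probability 0.5. Shapes follow the
# C, T, H, W convention used by HelioNetCDFDataset below.
def _example_channel_masker() -> None:
    masker = RandomChannelMaskerTransform(
        num_channels=8,
        num_mask_aia_channels=2,
        phase="train",
        drop_hmi_probability=0.5,
    )
    masked = masker(torch.ones(8, 2, 16, 16))
    # Exactly two channels are zeroed, plus possibly the HMI channel.
    n_zeroed = int((masked.abs().sum(dim=(1, 2, 3)) == 0).sum())
    assert n_zeroed in (2, 3)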


class HelioNetCDFDataset(Dataset):
    """
    PyTorch dataset to load a curated dataset from the NASA Solar Dynamics
    Observatory (SDO) mission stored as NetCDF files, with handling for
    variable timesteps.

    Internally maintains two databases. The first is `self.index`. This takes
    the form

                                                                          path  present
        timestep
        2011-01-01 00:00:00  /lustre/fs0/scratch/shared/data/2011/01/Arka_2...        1
        2011-01-01 00:12:00  /lustre/fs0/scratch/shared/data/2011/01/Arka_2...        1
        ...                                                                ...      ...
        2012-11-30 23:48:00  /lustre/fs0/scratch/shared/data/2012/11/Arka_2...        1

    The second is `self.valid_indices`. This is simply a list of timestamps --
    entries in the index of `self.index` -- which define valid samples. A
    sample is valid when all timestamps that can be reached from it via
    entries in time_delta_input_minutes and time_delta_target_minutes are
    present.
    """

    def __init__(
        self,
        index_path: str,
        time_delta_input_minutes: list[int],
        time_delta_target_minutes: int,
        n_input_timestamps: int,
        rollout_steps: int,
        scalers=None,
        num_mask_aia_channels: int = 0,
        drop_hmi_probability: float = 0.0,
        use_latitude_in_learned_flow=False,
        channels: list[str] | None = None,
        phase="train",
        pooling: int | None = None,
        random_vert_flip: bool = False,
    ):
        self.scalers = scalers
        self.phase = phase
        self.channels = channels
        self.num_mask_aia_channels = num_mask_aia_channels
        self.drop_hmi_probability = drop_hmi_probability
        self.n_input_timestamps = n_input_timestamps
        self.rollout_steps = rollout_steps
        self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
        self.pooling = pooling if pooling is not None else 1
        self.random_vert_flip = random_vert_flip

        if self.channels is None:
            # Default to the seven AIA wavelength channels plus HMI.
            self.channels = [
                "0094",
                "0131",
                "0171",
                "0193",
                "0211",
                "0304",
                "0335",
                "hmi",
            ]
        self.in_channels = len(self.channels)

        self.masker = RandomChannelMaskerTransform(
            num_channels=self.in_channels,
            num_mask_aia_channels=self.num_mask_aia_channels,
            phase=self.phase,
            drop_hmi_probability=self.drop_hmi_probability,
        )

        # Input offsets, sorted ascending; the largest offset is always used
        # as an input.
        self.time_delta_input_minutes = sorted(
            np.timedelta64(t, "m") for t in time_delta_input_minutes
        )
        # rollout_steps + 1 future offsets: the immediate target plus one per
        # rollout step.
        self.time_delta_target_minutes = [
            np.timedelta64(iroll * time_delta_target_minutes, "m")
            for iroll in range(1, rollout_steps + 2)
        ]

        # Build the timestamp-indexed database of available files.
        self.index = pd.read_csv(index_path)
        self.index = self.index[self.index["present"] == 1]
        self.index["timestep"] = pd.to_datetime(self.index["timestep"]).values.astype(
            "datetime64[ns]"
        )
        self.index.set_index("timestep", inplace=True)
        self.index.sort_index(inplace=True)

        self.valid_indices = self.filter_valid_indices()
        self.adjusted_length = len(self.valid_indices)

        self.rank = get_rank()
        self.logger: Logger | None = None

    def create_logger(self):
        """
        Creates a logger attached to self.logger. The logger name includes a
        timestamp as well as the data process's rank, phase, and process ID.
        """
        os.makedirs("logs/data", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%SZ")
        pid = os.getpid()
        self.logger = create_logger(
            output_dir="logs/data",
            dist_rank=self.rank,
            name=f"{timestamp}_{self.rank:>03}_data_{self.phase}_{pid}",
        )

    def filter_valid_indices(self):
        """
        Extracts timestamps from the index of `self.index` that define valid
        samples.

        Returns:
            List of timestamps.
        """
        valid_indices = []
        # List concatenation: every offset a sample must be able to reach.
        time_deltas = np.unique(
            self.time_delta_input_minutes + self.time_delta_target_minutes
        )

        for reference_timestep in self.index.index:
            required_timesteps = reference_timestep + time_deltas

            if all(t in self.index.index for t in required_timesteps):
                valid_indices.append(reference_timestep)

        return valid_indices
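
    # Illustrative example of the validity rule (comment only, hypothetical
    # offsets): with input offsets [-60, 0] minutes and a single 60-minute
    # target, the reference timestep 2011-01-01 01:00 is valid only if
    # 2011-01-01 00:00 and 2011-01-01 02:00 are also present in the index.
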
    def __len__(self):
        return self.adjusted_length

    def __getitem__(self, idx: int) -> dict:
        """
        Args:
            idx: Index of sample to load. (PyTorch standard.)
        Returns:
            Dictionary with the following keys. The values are tensors with
            shapes as follows:
                ts (torch.Tensor): C, T, H, W
                time_delta_input (torch.Tensor): T
                input_latitude (torch.Tensor): T
                forecast (torch.Tensor): C, L, H, W
                lead_time_delta (torch.Tensor): L
                forecast_latitude (torch.Tensor): L
            C - Channels, T - Input times, H - Image height, W - Image width,
            L - Lead time.
        """
        if self.logger is None:
            self.create_logger()
            self.logger.info(f"HelioNetCDFDataset of length {self.__len__()}.")

        exception_counter = 0
        max_exception = 100

        self.logger.info(f"Starting to retrieve index {idx}.")

        # If loading fails (e.g. because of a corrupt file), retry with the
        # next valid index, giving up after max_exception attempts.
        while True:
            try:
                sample = self._get_index_data(idx)
            except Exception as e:
                exception_counter += 1
                if exception_counter >= max_exception:
                    raise e

                reference_timestep = self.valid_indices[idx]
                self.logger.warning(
                    f"Failed retrieving index {idx}. Timestamp {reference_timestep}. "
                    f"Attempt {exception_counter}."
                )

                idx = (idx + 1) % self.__len__()
            else:
                self.logger.info(f"Returning index {idx}.")
                return sample

    def _get_index_data(self, idx: int) -> dict:
        """
        Args:
            idx: Index of sample to load. (PyTorch standard.)
        Returns:
            Dictionary with the following keys. The values are tensors with
            shapes as follows:
                ts (torch.Tensor): C, T, H, W
                time_delta_input (torch.Tensor): T
                input_latitude (torch.Tensor): T
                forecast (torch.Tensor): C, L, H, W
                lead_time_delta (torch.Tensor): L
                forecast_latitude (torch.Tensor): L
            C - Channels, T - Input times, H - Image height, W - Image width,
            L - Lead time.
        """
        # Always keep the final (largest) input offset; randomly sample the
        # remaining n_input_timestamps - 1 input offsets, then append all
        # target offsets.
        time_deltas = np.array(
            sorted(
                random.sample(
                    self.time_delta_input_minutes[:-1], self.n_input_timestamps - 1
                )
            )
            + [self.time_delta_input_minutes[-1]]
            + self.time_delta_target_minutes
        )
        reference_timestep = self.valid_indices[idx]
        required_timesteps = reference_timestep + time_deltas

        # Load and normalize one C, H, W array per required timestamp.
        sequence_data = [
            self.transform_data(
                self.load_nc_data(
                    self.index.loc[timestep, "path"], timestep, self.channels
                )
            )
            for timestep in required_timesteps
        ]

        # Split into inputs and the rollout_steps + 1 target steps, stacking
        # along a new time axis: C, T, H, W and C, L, H, W.
        inputs = sequence_data[: -self.rollout_steps - 1]
        targets = sequence_data[-self.rollout_steps - 1 :]

        stacked_inputs = np.stack(inputs, axis=1)
        stacked_targets = np.stack(targets, axis=1)

        timestamps_input = required_timesteps[: -self.rollout_steps - 1]
        timestamps_targets = required_timesteps[-self.rollout_steps - 1 :]

        # Randomly mask AIA channels and/or drop the HMI channel.
        if self.num_mask_aia_channels > 0 or self.drop_hmi_probability:
            stacked_inputs = self.masker(stacked_inputs)

        # Time differences, in hours, between the final input timestamp and
        # each input / target timestamp.
        time_delta_input_float = (
            time_deltas[-self.rollout_steps - 2]
            - time_deltas[: -self.rollout_steps - 1]
        ) / np.timedelta64(1, "h")
        time_delta_input_float = time_delta_input_float.astype(np.float32)

        lead_time_delta_float = (
            time_deltas[-self.rollout_steps - 2]
            - time_deltas[-self.rollout_steps - 1 :]
        ) / np.timedelta64(1, "h")
        lead_time_delta_float = lead_time_delta_float.astype(np.float32)

        metadata = {
            "timestamps_input": timestamps_input,
            "timestamps_targets": timestamps_targets,
        }

        if self.random_vert_flip:
            if torch.bernoulli(torch.ones(()) / 2) == 1:
                # Flip both inputs and targets along the height axis.
                stacked_inputs = torch.flip(
                    torch.as_tensor(stacked_inputs), dims=[-2]
                )
                stacked_targets = torch.flip(
                    torch.as_tensor(stacked_targets), dims=[-2]
                )

        if self.use_latitude_in_learned_flow:
            from sunpy.coordinates.ephemeris import get_earth

            sequence_latitude = [
                get_earth(timestep).lat.value for timestep in required_timesteps
            ]
            input_latitudes = sequence_latitude[: -self.rollout_steps - 1]
            target_latitude = sequence_latitude[-self.rollout_steps - 1 :]

            return {
                "ts": stacked_inputs,
                "time_delta_input": time_delta_input_float,
                "input_latitudes": input_latitudes,
                "forecast": stacked_targets,
                "lead_time_delta": lead_time_delta_float,
                "forecast_latitude": target_latitude,
            }, metadata

        return {
            "ts": stacked_inputs,
            "time_delta_input": time_delta_input_float,
            "forecast": stacked_targets,
            "lead_time_delta": lead_time_delta_float,
        }, metadata

    def load_nc_data(
        self, filepath: str, timestep: pd.Timestamp, channels: list[str]
    ) -> np.ndarray:
        """
        Args:
            filepath: String or Pathlike. Points to a NetCDF file to open.
            timestep: Identifies the timestamp to retrieve.
            channels: Names of the channels (NetCDF variables) to load.
        Returns:
            Numpy array of shape (C, H, W).
        """
        self.logger.info(f"Reading file {filepath}.")

        with xr.open_dataset(
            filepath, engine="h5netcdf", chunks=None, cache=False,
        ) as ds:
            data = ds[channels].to_array().load().to_numpy()

        return data
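
    # Minimal usage sketch (comment only; the filename is hypothetical): one
    # 2D image per requested channel, stacked along axis 0.
    #
    #     data = self.load_nc_data("sdo_20110101_0000.nc", timestep, ["0094", "hmi"])
    #     data.shape  # (2, H, W)
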
    @cache
    def transformation_inputs(
        self,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        # Cached since the per-channel statistics never change over the
        # lifetime of the dataset.
        means = np.array([self.scalers[ch].mean for ch in self.channels])
        stds = np.array([self.scalers[ch].std for ch in self.channels])
        epsilons = np.array([self.scalers[ch].epsilon for ch in self.channels])
        sl_scale_factors = np.array(
            [self.scalers[ch].sl_scale_factor for ch in self.channels]
        )

        return means, stds, epsilons, sl_scale_factors

    def transform_data(self, data: np.ndarray) -> np.ndarray:
        """
        Applies scalers. Uses the non-numba `transform`; see `fast_transform`
        for a numba-accelerated variant.

        Args:
            data: Numpy array of shape (C, H, W).
        Returns:
            Numpy array of shape (C, H, W). Data type float32.
        """
        assert data.ndim == 3

        # Optional spatial downsampling by mean pooling.
        if self.pooling > 1:
            data = skimage.measure.block_reduce(
                data, block_size=(1, self.pooling, self.pooling), func=np.mean
            )

        means, stds, epsilons, sl_scale_factors = self.transformation_inputs()
        result_np = transform(data, means, stds, sl_scale_factors, epsilons)
        return result_np
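

# Illustrative end-to-end sketch (comment only; paths and scaler values are
# hypothetical). `SimpleNamespace` stands in for the scaler objects the real
# pipeline provides; only the attributes read in `transformation_inputs` are
# needed.
#
#     from types import SimpleNamespace
#
#     scaler = SimpleNamespace(mean=0.0, std=1.0, epsilon=1e-8, sl_scale_factor=1.0)
#     dataset = HelioNetCDFDataset(
#         index_path="index.csv",  # hypothetical index CSV
#         time_delta_input_minutes=[-60, 0],
#         time_delta_target_minutes=60,
#         n_input_timestamps=2,
#         rollout_steps=1,
#         scalers={ch: scaler for ch in
#                  ["0094", "0131", "0171", "0193", "0211", "0304", "0335", "hmi"]},
#     )
#     sample, metadata = dataset[0]
#     sample["ts"].shape        # (C, T, H, W)
#     sample["forecast"].shape  # (C, L, H, W)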