import os
import sys
import random
from datetime import datetime
from functools import cache
from logging import Logger

import numpy as np
import pandas as pd
import skimage.measure
import torch
import xarray as xr
import hdf5plugin  # imported for its side effect: registers HDF5 compression filter plugins
from numba import njit, prange
from torch.utils.data import Dataset

from surya.utils.distributed import get_rank
from surya.utils.log import create_logger

@njit(parallel=True)
def fast_transform(data, means, stds, sl_scale_factors, epsilons):
    """
    Implements the signum-log transform using numba for speed.

    Notes:
        - This must reside outside the class definition from which it is called.
        - We used this function during pretraining for faster data loading. On select
          GPU clusters, however, it leads to the system hanging when data loading
          happens outside the GPU thread. See `transform` below for a
          non-numba-enhanced version.

    Args:
        data: Numpy array of shape (C, H, W).
        means: Numpy array of shape (C,). Mean per channel.
        stds: Numpy array of shape (C,). Standard deviation per channel.
        sl_scale_factors: Numpy array of shape (C,). Signum-log scale factors.
        epsilons: Numpy array of shape (C,). Constants to avoid division by zero.
    Returns:
        Numpy array of shape (C, H, W).
    """
    C, H, W = data.shape
    out = np.empty((C, H, W), dtype=np.float32)
    for c in prange(C):
        mean = means[c]
        std = stds[c]
        eps = epsilons[c]
        sl_scale_factor = sl_scale_factors[c]
        for i in range(H):
            for j in range(W):
                val = data[c, i, j]
                val = val * sl_scale_factor
                if val >= 0:
                    val = np.log1p(val)
                else:
                    val = -np.log1p(-val)
                out[c, i, j] = (val - mean) / (std + eps)
    return out

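# Illustration (a minimal sketch, not part of the pipeline): for a single value x,
# the signum-log transform computes sign(x * s) * log(1 + |x * s|) and then
# standardizes with the per-channel mean and std. The numbers below are synthetic.
#
#   x = -100.0
#   s = 1.0                       # sl_scale_factor
#   v = -np.log1p(abs(x * s))     # == -log(101) ~ -4.615
#   z = (v - mean) / (std + eps)  # standardized output
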
def transform(
    data: np.ndarray,
    means: np.ndarray,
    stds: np.ndarray,
    sl_scale_factors: np.ndarray,
    epsilons: np.ndarray,
) -> np.ndarray:
    """
    Implements the signum-log transform. Drop-in replacement for the
    `fast_transform` function above.

    Args:
        data: Numpy array of shape (C, H, W).
        means: Numpy array of shape (C,). Mean per channel.
        stds: Numpy array of shape (C,). Standard deviation per channel.
        sl_scale_factors: Numpy array of shape (C,). Signum-log scale factors.
        epsilons: Numpy array of shape (C,). Constants to avoid division by zero.
    Returns:
        Numpy array of shape (C, H, W).
    """
    means = means.reshape(*means.shape, 1, 1)
    stds = stds.reshape(*stds.shape, 1, 1)
    sl_scale_factors = sl_scale_factors.reshape(*sl_scale_factors.shape, 1, 1)
    epsilons = epsilons.reshape(*epsilons.shape, 1, 1)

    data = data * sl_scale_factors
    data = np.sign(data) * np.log1p(np.abs(data))
    data = (data - means) / (stds + epsilons)
    return data

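# Usage sketch (synthetic shapes and statistics, for illustration only): the
# vectorized `transform` and the numba `fast_transform` should agree to within
# floating-point error.
#
#   rng = np.random.default_rng(0)
#   dummy = rng.normal(scale=50.0, size=(8, 64, 64)).astype(np.float32)
#   m = np.zeros(8); s = np.ones(8)
#   k = np.ones(8); e = np.full(8, 1e-8)
#   z = transform(dummy, m, s, k, e)
#   z_fast = fast_transform(dummy, m, s, k, e)
#   np.testing.assert_allclose(z, z_fast, rtol=1e-5)
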
@njit(parallel=True)
def inverse_fast_transform(data, means, stds, sl_scale_factors, epsilons):
    """
    Implements the inverse signum-log transform using numba for speed.

    Args:
        data: Numpy array of shape (C, H, W).
        means: Numpy array of shape (C,). Mean per channel.
        stds: Numpy array of shape (C,). Standard deviation per channel.
        sl_scale_factors: Numpy array of shape (C,). Signum-log scale factors.
        epsilons: Numpy array of shape (C,). Constants to avoid division by zero.
    Returns:
        Numpy array of shape (C, H, W).
    """
    C, H, W = data.shape
    out = np.empty((C, H, W), dtype=np.float32)
    for c in prange(C):
        mean = means[c]
        std = stds[c]
        eps = epsilons[c]
        sl_scale_factor = sl_scale_factors[c]
        for i in range(H):
            for j in range(W):
                val = data[c, i, j]
                val = val * (std + eps) + mean
                if val >= 0:
                    val = np.expm1(val)
                else:
                    val = -np.expm1(-val)
                val = val / sl_scale_factor
                out[c, i, j] = val
    return out

def inverse_transform_single_channel(data, mean, std, sl_scale_factor, epsilon):
    """
    Implements the inverse signum-log transform for a single channel.

    Args:
        data: Numpy array of shape (H, W).
        mean: Scalar. Mean of the channel.
        std: Scalar. Standard deviation of the channel.
        sl_scale_factor: Scalar. Signum-log scale factor of the channel.
        epsilon: Scalar. Constant to avoid division by zero.
    Returns:
        Numpy array of shape (H, W).
    """
    data = data * (std + epsilon) + mean
    data = np.sign(data) * np.expm1(np.abs(data))
    data = data / sl_scale_factor
    return data

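# Round-trip sketch (illustrative values only): a forward signum-log transform of
# one channel followed by the single-channel inverse recovers the original data up
# to floating-point error.
#
#   raw = np.array([[-100.0, 0.0, 250.0]])
#   z = np.sign(raw) * np.log1p(np.abs(raw))
#   z = (z - 0.0) / (1.0 + 1e-8)
#   back = inverse_transform_single_channel(z, mean=0.0, std=1.0,
#                                           sl_scale_factor=1.0, epsilon=1e-8)
#   np.testing.assert_allclose(back, raw, rtol=1e-5)
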
class RandomChannelMaskerTransform:
    def __init__(
        self, num_channels, num_mask_aia_channels, phase, drop_hmi_probability
    ):
        """
        Initialize the RandomChannelMaskerTransform class as a transform.

        Args:
            num_channels: Total number of channels in the input (leading dimension
                of the tensor).
            num_mask_aia_channels: Number of channels to randomly mask.
            phase: Phase ("train", "valid", ...) the transform is used in.
            drop_hmi_probability: Probability of zeroing out the HMI (last) channel.
        """
        self.num_channels = num_channels
        self.num_mask_aia_channels = num_mask_aia_channels
        self.phase = phase
        self.drop_hmi_probability = drop_hmi_probability

    def __call__(self, input_tensor):
        C, T, H, W = input_tensor.shape  # Channels, input times, height, width

        # Randomly select channels to mask
        channels_to_mask = random.sample(range(C), self.num_mask_aia_channels)

        # Create a broadcastable mask of shape [C, 1, 1, 1]
        mask = torch.ones((C, 1, 1, 1))
        mask[channels_to_mask, ...] = 0  # Set selected channels to zero

        # Apply the mask (produces a new tensor; the input itself is not modified)
        masked_tensor = input_tensor * mask

        if self.drop_hmi_probability > random.random():
            masked_tensor[-1, ...] = 0

        return masked_tensor

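# Usage sketch (shapes are illustrative; in the dataset below the masker is applied
# to the stacked input of shape [C, T, H, W]):
#
#   masker = RandomChannelMaskerTransform(
#       num_channels=8, num_mask_aia_channels=2, phase="train",
#       drop_hmi_probability=0.2,
#   )
#   x = torch.randn(8, 2, 64, 64)   # C, T, H, W
#   x_masked = masker(x)            # two random channels zeroed; HMI zeroed w.p. 0.2
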
class HelioNetCDFDataset(Dataset):
    """
    PyTorch dataset to load a curated dataset from the NASA Solar Dynamics
    Observatory (SDO) mission stored as NetCDF files, with handling for variable
    timesteps.

    Internally maintains two databases. The first is `self.index`, which takes the
    form

                                                                          path  present
        timestep
        2011-01-01 00:00:00  /lustre/fs0/scratch/shared/data/2011/01/Arka_2...        1
        2011-01-01 00:12:00  /lustre/fs0/scratch/shared/data/2011/01/Arka_2...        1
        ...                                                                ...      ...
        2012-11-30 23:48:00  /lustre/fs0/scratch/shared/data/2012/11/Arka_2...        1

    The second is `self.valid_indices`. This is simply a list of timestamps --
    entries in the index of `self.index` -- that define valid samples. A sample is
    valid when all timestamps reachable from it via `time_delta_input_minutes` and
    `time_delta_target_minutes` are present in the index.
    """

    def __init__(
        self,
        index_path: str,
        time_delta_input_minutes: list[int],
        time_delta_target_minutes: int,
        n_input_timestamps: int,
        rollout_steps: int,
        scalers=None,
        num_mask_aia_channels: int = 0,
        drop_hmi_probability: float = 0.0,
        use_latitude_in_learned_flow=False,
        channels: list[str] | None = None,
        phase="train",
        pooling: int | None = None,
        random_vert_flip: bool = False,
    ):
        self.scalers = scalers
        self.phase = phase
        self.channels = channels
        self.num_mask_aia_channels = num_mask_aia_channels
        self.drop_hmi_probability = drop_hmi_probability
        self.n_input_timestamps = n_input_timestamps
        self.rollout_steps = rollout_steps
        self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
        self.pooling = pooling if pooling is not None else 1
        self.random_vert_flip = random_vert_flip

        if self.channels is None:
            # AIA + HMI channels
            self.channels = [
                "0094",
                "0131",
                "0171",
                "0193",
                "0211",
                "0304",
                "0335",
                "hmi",
            ]
        self.in_channels = len(self.channels)

        self.masker = RandomChannelMaskerTransform(
            num_channels=self.in_channels,
            num_mask_aia_channels=self.num_mask_aia_channels,
            phase=self.phase,
            drop_hmi_probability=self.drop_hmi_probability,
        )

        # Convert time deltas to numpy timedelta64
        self.time_delta_input_minutes = sorted(
            np.timedelta64(t, "m") for t in time_delta_input_minutes
        )
        self.time_delta_target_minutes = [
            np.timedelta64(iroll * time_delta_target_minutes, "m")
            for iroll in range(1, rollout_steps + 2)
        ]

        # Create the index
        self.index = pd.read_csv(index_path)
        self.index = self.index[self.index["present"] == 1]
        self.index["timestep"] = pd.to_datetime(self.index["timestep"]).values.astype(
            "datetime64[ns]"
        )
        self.index.set_index("timestep", inplace=True)
        self.index.sort_index(inplace=True)

        # Filter out rows where the sequence is not fully present
        self.valid_indices = self.filter_valid_indices()
        self.adjusted_length = len(self.valid_indices)

        self.rank = get_rank()
        self.logger: Logger | None = None

    def create_logger(self):
        """
        Creates a logger attached to `self.logger`. The logger is identified by the
        SLURM job ID as well as the data process's rank and process ID.
        """
        os.makedirs("logs/data", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%SZ")
        pid = os.getpid()
        self.logger = create_logger(
            output_dir="logs/data",
            dist_rank=self.rank,
            name=f"{timestamp}_{self.rank:>03}_data_{self.phase}_{pid}",
        )

    def filter_valid_indices(self):
        """
        Extracts timestamps from the index of `self.index` that define valid
        samples.

        Returns:
            List of timestamps.
        """
        valid_indices = []
        time_deltas = np.unique(
            self.time_delta_input_minutes + self.time_delta_target_minutes
        )

        for reference_timestep in self.index.index:
            required_timesteps = reference_timestep + time_deltas
            if all(t in self.index.index for t in required_timesteps):
                valid_indices.append(reference_timestep)

        return valid_indices

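    # Illustration (synthetic timestamps, not actual data): with input deltas of
    # [-60, 0] minutes and a single target delta of +60 minutes, a reference
    # timestamp T is valid only if T-60min, T, and T+60min all appear in the index.
    #
    #   idx = pd.to_datetime(["2011-01-01 00:00", "2011-01-01 01:00",
    #                         "2011-01-01 02:00", "2011-01-01 04:00"])
    #   deltas = np.array([np.timedelta64(-60, "m"), np.timedelta64(0, "m"),
    #                      np.timedelta64(60, "m")])
    #   valid = [t for t in idx if all((t + d) in idx for d in deltas)]
    #   # -> only 2011-01-01 01:00 qualifies
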
    def __len__(self):
        return self.adjusted_length

    def __getitem__(self, idx: int) -> tuple[dict, dict]:
        """
        Args:
            idx: Index of sample to load. (PyTorch standard.)
        Returns:
            Tuple of (sample, metadata). `sample` is a dictionary with the following
            keys; the values are tensors with shape as follows:
                ts (torch.Tensor): C, T, H, W
                time_delta_input (torch.Tensor): T
                input_latitudes (torch.Tensor): T
                forecast (torch.Tensor): C, L, H, W
                lead_time_delta (torch.Tensor): L
                forecast_latitude (torch.Tensor): L
            C - Channels, T - Input times, H - Image height, W - Image width,
            L - Lead time. `metadata` holds the input and target timestamps.
            Failed reads are retried with the next valid index, up to 100 attempts.
        """
        if self.logger is None:
            self.create_logger()
            self.logger.info(f"HelioNetCDFDataset of length {self.__len__()}.")

        exception_counter = 0
        max_exception = 100

        self.logger.info(f"Starting to retrieve index {idx}.")
        while True:
            try:
                sample = self._get_index_data(idx)
            except Exception as e:
                exception_counter += 1
                if exception_counter >= max_exception:
                    raise e
                reference_timestep = self.valid_indices[idx]
                self.logger.warning(
                    f"Failed retrieving index {idx}. Timestamp {reference_timestep}. "
                    f"Attempt {exception_counter}."
                )
                idx = (idx + 1) % self.__len__()
            else:
                self.logger.info(f"Returning index {idx}.")
                return sample

    def _get_index_data(self, idx: int) -> tuple[dict, dict]:
        """
        Args:
            idx: Index of sample to load. (PyTorch standard.)
        Returns:
            Tuple of (sample, metadata). `sample` is a dictionary with the following
            keys; the values are tensors with shape as follows:
                ts (torch.Tensor): C, T, H, W
                time_delta_input (torch.Tensor): T
                input_latitudes (torch.Tensor): T
                forecast (torch.Tensor): C, L, H, W
                lead_time_delta (torch.Tensor): L
                forecast_latitude (torch.Tensor): L
            C - Channels, T - Input times, H - Image height, W - Image width,
            L - Lead time. `metadata` holds the input and target timestamps.
        """
        # start_time = time.time()

        time_deltas = np.array(
            sorted(
                random.sample(
                    self.time_delta_input_minutes[:-1], self.n_input_timestamps - 1
                )
            )
            + [self.time_delta_input_minutes[-1]]
            + self.time_delta_target_minutes
        )

        reference_timestep = self.valid_indices[idx]
        required_timesteps = reference_timestep + time_deltas

        sequence_data = [
            self.transform_data(
                self.load_nc_data(
                    self.index.loc[timestep, "path"], timestep, self.channels
                )
            )
            for timestep in required_timesteps
        ]

        # Split sequence_data into inputs and targets
        inputs = sequence_data[: -self.rollout_steps - 1]
        targets = sequence_data[-self.rollout_steps - 1 :]

        stacked_inputs = np.stack(inputs, axis=1)
        stacked_targets = np.stack(targets, axis=1)

        timestamps_input = required_timesteps[: -self.rollout_steps - 1]
        timestamps_targets = required_timesteps[-self.rollout_steps - 1 :]

        if self.num_mask_aia_channels > 0 or self.drop_hmi_probability:
            # assert 0 < self.num_mask_aia_channels < self.in_channels, \
            #     f'num_mask_aia_channels = {self.num_mask_aia_channels} should lie between 0 and {self.in_channels}'
            stacked_inputs = self.masker(stacked_inputs)

        time_delta_input_float = (
            time_deltas[-self.rollout_steps - 2]
            - time_deltas[: -self.rollout_steps - 1]
        ) / np.timedelta64(1, "h")
        time_delta_input_float = time_delta_input_float.astype(np.float32)

        lead_time_delta_float = (
            time_deltas[-self.rollout_steps - 2]
            - time_deltas[-self.rollout_steps - 1 :]
        ) / np.timedelta64(1, "h")
        lead_time_delta_float = lead_time_delta_float.astype(np.float32)

        # print('LocalRank', int(os.environ["LOCAL_RANK"]),
        #       'GlobalRank', int(os.environ["RANK"]),
        #       'worker', torch.utils.data.get_worker_info().id,
        #       f': Processed Input: {idx} ', time.time() - start_time)

        metadata = {
            "timestamps_input": timestamps_input,
            "timestamps_targets": timestamps_targets,
        }

        if self.random_vert_flip:
            if torch.bernoulli(torch.ones(()) / 2) == 1:
                stacked_inputs = torch.flip(stacked_inputs, dims=[-2])
                stacked_targets = torch.flip(stacked_targets, dims=[-2])

        if self.use_latitude_in_learned_flow:
            from sunpy.coordinates.ephemeris import get_earth

            sequence_latitude = [
                get_earth(timestep).lat.value for timestep in required_timesteps
            ]
            input_latitudes = sequence_latitude[: -self.rollout_steps - 1]
            target_latitude = sequence_latitude[-self.rollout_steps - 1 :]

            return {
                "ts": stacked_inputs,
                "time_delta_input": time_delta_input_float,
                "input_latitudes": input_latitudes,
                "forecast": stacked_targets,
                "lead_time_delta": lead_time_delta_float,
                "forecast_latitude": target_latitude,
            }, metadata

        return {
            "ts": stacked_inputs,
            "time_delta_input": time_delta_input_float,
            "forecast": stacked_targets,
            "lead_time_delta": lead_time_delta_float,
        }, metadata

    def load_nc_data(
        self, filepath: str, timestep: pd.Timestamp, channels: list[str]
    ) -> np.ndarray:
        """
        Args:
            filepath: String or Path-like. Points to the NetCDF file to open.
            timestep: Identifies the timestamp to retrieve.
            channels: List of channel names (NetCDF variables) to load.
        Returns:
            Numpy array of shape (C, H, W).
        """
        self.logger.info(f"Reading file {filepath}.")
        with xr.open_dataset(
            filepath, engine="h5netcdf", chunks=None, cache=False,
        ) as ds:
            data = ds[channels].to_array().load().to_numpy()

        return data

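    # Reading sketch (the path below is hypothetical): each indexed timestep maps to
    # one NetCDF file, with one 2D variable per channel. Importing `hdf5plugin` at
    # module level is typically what makes files compressed with non-default HDF5
    # filters readable through the h5netcdf engine.
    #
    #   with xr.open_dataset("Arka_20110101_0000.nc", engine="h5netcdf") as ds:
    #       arr = ds[["0094", "hmi"]].to_array().to_numpy()  # shape (2, H, W)
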
    def transformation_inputs(
        self,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Assembles per-channel means, standard deviations, epsilons, and signum-log
        scale factors from `self.scalers`, in the channel order of `self.channels`.
        """
        means = np.array([self.scalers[ch].mean for ch in self.channels])
        stds = np.array([self.scalers[ch].std for ch in self.channels])
        epsilons = np.array([self.scalers[ch].epsilon for ch in self.channels])
        sl_scale_factors = np.array(
            [self.scalers[ch].sl_scale_factor for ch in self.channels]
        )

        return means, stds, epsilons, sl_scale_factors

    def transform_data(self, data: np.ndarray) -> np.ndarray:
        """
        Applies pooling and the per-channel scalers.

        Args:
            data: Numpy array of shape (C, H, W).
        Returns:
            Numpy array of shape (C, H, W) with the signum-log transform and
            normalization applied. If `self.pooling > 1`, H and W are reduced by
            that factor.
        Uses:
            numba to speed up the transform;
            tvk-srm-heliofm environment cloned from srm-heliofm with numba added;
            tvk_dgx_slurm.sh shell script modified to use the new environment and
            a new job name;
            train_spectformer_dgx.yaml new job name.
        """
        assert data.ndim == 3

        if self.pooling > 1:
            data = skimage.measure.block_reduce(
                data, block_size=(1, self.pooling, self.pooling), func=np.mean
            )

        means, stds, epsilons, sl_scale_factors = self.transformation_inputs()
        result_np = transform(data, means, stds, sl_scale_factors, epsilons)

        return result_np
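
    # Shape sketch (synthetic numbers, assuming pooling=2 and 8 channels): a raw
    # (8, 4096, 4096) frame is mean-pooled to (8, 2048, 2048) by block_reduce and
    # then normalized channel-wise with the signum-log transform.
    #
    #   pooled = skimage.measure.block_reduce(
    #       np.ones((8, 4096, 4096), dtype=np.float32),
    #       block_size=(1, 2, 2), func=np.mean,
    #   )
    #   pooled.shape  # -> (8, 2048, 2048)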