|
import yaml |
|
import logging |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
|
|
import matplotlib.pyplot as plt |
|
import sunpy.visualization.colormaps as sunpy_cm |
|
|
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
|
|
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel |
|
from surya.models.helio_spectformer import HelioSpectFormer |
|
from surya.utils.data import build_scalers, custom_collate_fn |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Channel names offered in the "SDO Band" dropdown. The position of a name in
# this list is the channel index that run_inference uses to slice both the
# input stack and the model output, so the order must not be changed.
SDO_CHANNELS = [
    "aia94", "aia131", "aia171", "aia193",
    "aia211", "aia304", "aia335", "aia1600",
    "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
]
|
|
|
@dataclass
class SDOImage:
    """Record bundling one image array with its channel name, timestamp and a type tag.

    Nothing in the visible part of this file constructs or consumes this
    class, so the exact conventions for each field are not established here.
    """

    # Channel name; presumably one of the entries of SDO_CHANNELS -- TODO confirm.
    channel: str
    # Image pixels as a NumPy array.
    data: np.ndarray
    # Time label stored as a string; the expected format is not shown in this file.
    timestamp: str
    # Free-form tag; its allowed values are not visible in this file.
    type: str
|
|
|
def download_data():
    """Download the model assets and a slice of validation data from the Hugging Face Hub.

    Two snapshots are requested, in this order:

    1. ``config.yaml``, ``scalers.yaml`` and ``surya.366m.v1.pt`` from the
       ``nasa-ibm-ai4science/Surya-1.0`` repository into ``data/Surya-1.0``.
    2. Files matching ``20140107_1[5-9]??.nc`` from the
       ``nasa-ibm-ai4science/Surya-1.0_validation_data`` dataset repository
       into ``data/Surya-1.0_validation_data``.

    No explicit access token is passed (``token=None``).
    """
    model_assets = dict(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir="data/Surya-1.0",
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
        token=None,
    )
    validation_files = dict(
        repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
        repo_type="dataset",
        local_dir="data/Surya-1.0_validation_data",
        allow_patterns="20140107_1[5-9]??.nc",
        token=None,
    )

    # Model assets first, then the validation data.
    for request in (model_assets, validation_files):
        snapshot_download(**request)
|
|
|
def get_dataset(config, scalers) -> HelioNetCDFDataset:
    """Build the validation-phase dataset from the ``data`` section of the config.

    Args:
        config: Parsed configuration mapping; only ``config["data"]`` is read.
        scalers: Scaler objects forwarded unchanged to the dataset.

    Returns:
        A ``HelioNetCDFDataset`` reading the index ``tests/test_surya_index.csv``
        with ``phase="valid"`` and ``rollout_steps=0``.
    """
    data_cfg = config["data"]
    input_deltas = data_cfg["time_delta_input_minutes"]

    dataset = HelioNetCDFDataset(
        index_path="tests/test_surya_index.csv",
        time_delta_input_minutes=input_deltas,
        time_delta_target_minutes=data_cfg["time_delta_target_minutes"],
        # One input timestamp per configured input time delta.
        n_input_timestamps=len(input_deltas),
        rollout_steps=0,
        channels=data_cfg["sdo_channels"],
        drop_hmi_probability=data_cfg["drop_hmi_probability"],
        num_mask_aia_channels=data_cfg["num_mask_aia_channels"],
        use_latitude_in_learned_flow=data_cfg["use_latitude_in_learned_flow"],
        scalers=scalers,
        phase="valid",
        pooling=data_cfg["pooling"],
        random_vert_flip=data_cfg["random_vert_flip"],
    )
    logger.info(f"Initialized the dataset. {len(dataset)} samples.")

    return dataset
|
|
|
def get_scalers() -> dict:
    """Load the scaler description from disk and build the scaler objects.

    Reads ``data/Surya-1.0/scalers.yaml`` with ``yaml.safe_load`` and passes
    the parsed content to ``build_scalers``.

    Returns:
        The object returned by ``build_scalers`` (annotated as a dict).
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)

    scalers = build_scalers(info=scalers_info)
    logger.info("Built the scalers.")
    return scalers
|
|
|
def get_model_from_config(config) -> HelioSpectFormer:
    """Instantiate a ``HelioSpectFormer`` from the parsed configuration.

    Args:
        config: Parsed configuration mapping. Architecture hyper-parameters
            are read from ``config["model"]``; the number of input channels
            and the time-embedding size are derived from ``config["data"]``.

    Returns:
        The constructed model. It is created with ``init_weights=False`` and
        no checkpoint is loaded here.
    """
    model_cfg = config["model"]
    data_cfg = config["data"]

    time_embedding = {
        "type": "linear",
        "time_dim": len(data_cfg["time_delta_input_minutes"]),
    }

    model = HelioSpectFormer(
        img_size=model_cfg["img_size"],
        patch_size=model_cfg["patch_size"],
        in_chans=len(data_cfg["sdo_channels"]),
        embed_dim=model_cfg["embed_dim"],
        time_embedding=time_embedding,
        depth=model_cfg["depth"],
        n_spectral_blocks=model_cfg["n_spectral_blocks"],
        num_heads=model_cfg["num_heads"],
        mlp_ratio=model_cfg["mlp_ratio"],
        drop_rate=model_cfg["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_cfg["window_size"],
        dp_rank=model_cfg["dp_rank"],
        learned_flow=model_cfg["learned_flow"],
        # NOTE(review): this reuses the "learned_flow" value rather than a
        # dedicated "use_latitude_in_learned_flow" key (get_dataset reads one
        # from the data section). Kept as-is -- confirm this is intentional.
        use_latitude_in_learned_flow=model_cfg["learned_flow"],
        init_weights=False,
        # Hard-coded layer indices 0..9.
        checkpoint_layers=list(range(10)),
        rpe=model_cfg["rpe"],
        ensemble=model_cfg["ensemble"],
        finetune=model_cfg["finetune"],
    )
    logger.info("Initialized the model.")

    return model
|
|
|
def get_config() -> dict:
    """Parse ``data/Surya-1.0/config.yaml`` with ``yaml.safe_load`` and return the result."""
    with open("data/Surya-1.0/config.yaml") as config_file:
        return yaml.safe_load(config_file)
|
|
|
def setup():
    """Download all assets and prepare the dataset and model for inference.

    Steps, in order: download data, read the config, build the scalers, build
    the dataset, build the model, move it to the selected device, and load
    the checkpoint ``data/Surya-1.0/surya.366m.v1.pt`` with ``strict=True``.

    Returns:
        tuple: ``(dataset, model, device)`` where ``device`` is the current
        CUDA device index when CUDA is available and the string ``"cpu"``
        otherwise.
    """
    logger.info("Loading data ...")
    download_data()
    config = get_config()
    scalers = get_scalers()

    logger.info("Initializing dataset ...")
    dataset = get_dataset(config, scalers)

    logger.info("Initializing model ...")
    model = get_model_from_config(config)

    has_gpu = torch.cuda.is_available()
    device = torch.cuda.current_device() if has_gpu else "cpu"
    if has_gpu:
        logger.info(f"GPU detected. Running the test on device {device}.")
    else:
        logger.warning("No GPU detected. Running the test on CPU.")
    model.to(device)

    # Report the parameter count in millions.
    total_parameters = sum(p.numel() for p in model.parameters())
    n_parameters = total_parameters / 1e6
    logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")

    path_weights = "data/Surya-1.0/surya.366m.v1.pt"
    state_dict = torch.load(
        path_weights,
        map_location=torch.device(device),
        weights_only=True,
    )
    model.load_state_dict(state_dict, strict=True)
    logger.info("Loaded weights.")

    return dataset, model, device
|
|
|
def batch_step(
    model: HelioSpectFormer,
    sample_data: dict,
    sample_metadata: dict,
    device: int | str,
    hours_ahead: int = 1,
) -> np.ndarray:
    """Roll the model forward autoregressively and return the final forecast.

    The model is called ``hours_ahead`` times. After each call the oldest
    entry along axis 2 of the batched ``"ts"`` tensor is dropped and the
    latest forecast is appended, so the next call sees a window shifted by
    one step. Only the forecast of the last step is returned.

    Args:
        model: Callable taking a dict with the keys ``"ts"`` and
            ``"time_delta_input"`` and returning the forecast tensor.
        sample_data: Mapping holding NumPy arrays under ``"ts"`` and
            ``"time_delta_input"``. The arrays are not modified.
            (``"ts"`` is assumed to be laid out as channel, time, ... --
            TODO confirm against the dataset.)
        sample_metadata: Unused; kept for interface compatibility.
        device: Device the inputs are moved to ('cpu', 'cuda' or device number).
        hours_ahead: Number of model steps to roll out. Integral floats such
            as ``2.0`` are accepted and converted. Defaults to 1.

    Returns:
        np.ndarray: The last forecast as float32 with the batch axis removed.

    Raises:
        ValueError: If ``hours_ahead`` is smaller than 1.
    """
    # UI widgets may hand over the step count as a float, which range() rejects.
    n_steps = int(hours_ahead)
    if n_steps < 1:
        raise ValueError(f"hours_ahead must be at least 1, got {hours_ahead!r}.")

    # Add a leading batch axis of size 1 and move the inputs to the device.
    curr_batch = {
        key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
        for key in ("ts", "time_delta_input")
    }
    forecast_hat = model(curr_batch)

    for _ in range(n_steps - 1):
        # Slide the input window: drop the oldest entry along axis 2 and
        # append the previous forecast as the newest one.
        curr_batch["ts"] = torch.cat(
            (curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
            dim=2,
        )
        forecast_hat = model(curr_batch)

    return forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
|
|
|
|
|
def _colorize(image, cm_name, vmin, vmax):
    """Reverse the first axis of *image*, scale it by (vmin, vmax) and apply the named SunPy colormap."""
    return sunpy_cm.cmlist[cm_name]((image[::-1] - vmin) / (vmax - vmin), bytes=True)


def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
    """Run a forecast for one sample and render inputs and prediction for one channel.

    Uses the module-level ``dataset``, ``model`` and ``device``.

    Args:
        init_time_idx: Index into ``dataset.valid_indices`` selecting the
            initialization time (the dropdown delivers an index).
        plt_channel_idx: Index into ``SDO_CHANNELS`` of the channel to plot.
        hours_ahead: Number of forecast steps; integral floats are accepted.

    Returns:
        tuple: ``(timestamp_0, image_0, timestamp_1, image_1,
        output_timestamp, output_image)``. Timestamps are strings formatted
        as ``%Y-%m-%d %H:%M``; images are the colormapped arrays of the two
        input frames and of the prediction.

    Raises:
        gr.Error: If no initialization time or channel has been selected.
    """
    # Dropdowns without a selection deliver None; report that to the user
    # instead of failing later with an opaque indexing error.
    if init_time_idx is None:
        raise gr.Error("Please select an initialization time.")
    if plt_channel_idx is None:
        raise gr.Error("Please select an SDO band.")

    # Normalise once so that the timestamp arithmetic and the rollout agree.
    hours_ahead = int(hours_ahead)
    plt_channel_str = SDO_CHANNELS[plt_channel_idx]

    time_format = "%Y-%m-%d %H:%M"
    init_time = dataset.valid_indices[init_time_idx]
    # The first input frame is labelled one hour before the initialization time.
    input_timestamp_0 = (init_time - pd.Timedelta(1, "h")).strftime(time_format)
    input_timestamp_1 = init_time.strftime(time_format)
    output_timestamp = (init_time + pd.Timedelta(hours_ahead, "h")).strftime(time_format)

    sample_data, sample_metadata = dataset[init_time_idx]
    with torch.no_grad():
        model_output = batch_step(
            model,
            sample_data,
            sample_metadata,
            device,
            hours_ahead,
        )

    means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()

    def undo_scaling(scaled):
        # Invert the dataset's per-channel transformation for the plotted channel.
        return inverse_transform_single_channel(
            scaled,
            mean=means[plt_channel_idx],
            std=stds[plt_channel_idx],
            epsilon=epsilons[plt_channel_idx],
            sl_scale_factor=sl_scale_factors[plt_channel_idx],
        )

    # Color limits: the larger of the two frame minima and the smaller of the
    # two frame maxima.
    # NOTE(review): the limits are taken from the scaled inputs but applied
    # to the inverse-transformed images below -- confirm this is intended.
    vmin = float("-inf")
    vmax = float("inf")
    input_image = []
    for i in range(2):
        frame = sample_data["ts"][plt_channel_idx, i]
        input_image.append(undo_scaling(frame))
        vmin = max(vmin, frame.min())
        vmax = min(vmax, frame.max())

    # Channels named "aia..." use their own "sdoaia..." colormap; every other
    # channel is drawn with "hmimag".
    if plt_channel_str.startswith("aia"):
        cm_name = "sdo" + plt_channel_str
    else:
        cm_name = "hmimag"

    input_image = [_colorize(img, cm_name, vmin, vmax) for img in input_image]
    output_image = _colorize(
        undo_scaling(model_output[plt_channel_idx]), cm_name, vmin, vmax
    )

    return (
        input_timestamp_0,
        input_image[0],
        input_timestamp_1,
        input_image[1],
        output_timestamp,
        output_image,
    )
|
|
|
# Configure logging before setup() so its progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level state read by run_inference.
dataset, model, device = setup()

with gr.Blocks() as demo:
    gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")

    # Controls: both dropdowns deliver the *index* of the selected entry.
    with gr.Row():
        with gr.Column():
            init_time = gr.Dropdown(
                [v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
                label="Initialization time",
                multiselect=False,
                type="index",
            )
        with gr.Column():
            plt_channel = gr.Dropdown(
                [c.upper() for c in SDO_CHANNELS],
                label="SDO Band",
                value="AIA94",
                multiselect=False,
                type="index",
            )
    with gr.Row():
        hours_ahead = gr.Slider(
            minimum=1.0,
            maximum=6.0,
            step=1.0,
            label="Forecast step [hours ahead]",
        )
    with gr.Row():
        btn = gr.Button("Run")

    # Results: the two input frames followed by the prediction.
    with gr.Row():
        with gr.Column():
            input_timestamp_0 = gr.Textbox(label="Input 0")
            input_image_0 = gr.Image()
        with gr.Column():
            input_timestamp_1 = gr.Textbox(label="Input 1")
            input_image_1 = gr.Image()
        with gr.Column():
            output_timestamp = gr.Textbox(label="Prediction")
            output_image = gr.Image()

    # The button and the examples drive the same function with the same
    # components, in the order of run_inference's arguments and return values.
    inference_inputs = [init_time, plt_channel, hours_ahead]
    inference_outputs = [
        input_timestamp_0,
        input_image_0,
        input_timestamp_1,
        input_image_1,
        output_timestamp,
        output_image,
    ]

    btn.click(
        fn=run_inference,
        inputs=inference_inputs,
        outputs=inference_outputs,
    )

    with gr.Row():
        gr.Examples(
            examples=[
                ["2014-01-07 17:24", "AIA94", 2],
                ["2014-01-07 16:12", "AIA94", 6],
                ["2014-01-07 16:00", "AIA131", 1],
                ["2014-01-07 16:00", "HMI_M", 2],
            ],
            fn=run_inference,
            inputs=inference_inputs,
            outputs=inference_outputs,
            cache_examples=False,
        )

demo.launch()
|
|