import logging
from dataclasses import dataclass
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sunpy.visualization.colormaps as sunpy_cm
import torch
import torch.nn.functional as F
import yaml
from huggingface_hub import snapshot_download
from torch.utils.data import DataLoader

from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, custom_collate_fn

logger = logging.getLogger(__name__)

SDO_CHANNELS = [
    "aia94",
    "aia131",
    "aia171",
    "aia193",
    "aia211",
    "aia304",
    "aia335",
    "aia1600",
    "hmi_m",
    "hmi_bx",
    "hmi_by",
    "hmi_bz",
    "hmi_v",
]


@dataclass
class SDOImage:
    channel: str
    data: np.ndarray
    timestamp: str
    type: str


def download_data():
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir="data/Surya-1.0",
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
        token=None,
    )
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
        repo_type="dataset",
        local_dir="data/Surya-1.0_validation_data",
        allow_patterns="20140107_1[5-9]??.nc",
        token=None,
    )


def get_dataset(config, scalers) -> HelioNetCDFDataset:
    dataset = HelioNetCDFDataset(
        index_path="tests/test_surya_index.csv",
        time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
        time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
        n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
        rollout_steps=0,
        channels=config["data"]["sdo_channels"],
        drop_hmi_probability=config["data"]["drop_hmi_probability"],
        num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
        use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
        scalers=scalers,
        phase="valid",
        pooling=config["data"]["pooling"],
        random_vert_flip=config["data"]["random_vert_flip"],
    )
    logger.info(f"Initialized the dataset. {len(dataset)} samples.")
    return dataset
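
# A note on sample layout, inferred from how samples are consumed further down
# in this script: dataset[i] returns a (sample_data, sample_metadata) pair,
# where sample_data["ts"] is a numpy array shaped
# (channels, input_timestamps, height, width) and
# sample_data["time_delta_input"] holds the relative input times that feed the
# model's linear time embedding.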

def get_scalers() -> dict:
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    scalers = build_scalers(info=scalers_info)
    logger.info("Built the scalers.")
    return scalers


def get_model_from_config(config) -> HelioSpectFormer:
    model = HelioSpectFormer(
        img_size=config["model"]["img_size"],
        patch_size=config["model"]["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=config["model"]["embed_dim"],
        time_embedding={
            "type": "linear",
            "time_dim": len(config["data"]["time_delta_input_minutes"]),
        },
        depth=config["model"]["depth"],
        n_spectral_blocks=config["model"]["n_spectral_blocks"],
        num_heads=config["model"]["num_heads"],
        mlp_ratio=config["model"]["mlp_ratio"],
        drop_rate=config["model"]["drop_rate"],
        dtype=torch.bfloat16,
        window_size=config["model"]["window_size"],
        dp_rank=config["model"]["dp_rank"],
        learned_flow=config["model"]["learned_flow"],
        use_latitude_in_learned_flow=config["model"]["learned_flow"],
        init_weights=False,
        checkpoint_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        rpe=config["model"]["rpe"],
        ensemble=config["model"]["ensemble"],
        finetune=config["model"]["finetune"],
    )
    logger.info("Initialized the model.")
    return model


def get_config() -> dict:
    with open("data/Surya-1.0/config.yaml") as fp:
        config = yaml.safe_load(fp)
    return config


def setup():
    logger.info("Loading data ...")
    download_data()
    config = get_config()
    scalers = get_scalers()

    logger.info("Initializing dataset ...")
    dataset = get_dataset(config, scalers)

    logger.info("Initializing model ...")
    model = get_model_from_config(config)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        logger.info(f"GPU detected. Running the test on device {device}.")
    else:
        device = "cpu"
        logger.warning("No GPU detected. Running the test on CPU.")
    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
    logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")

    path_weights = "data/Surya-1.0/surya.366m.v1.pt"
    weights = torch.load(
        path_weights, map_location=torch.device(device), weights_only=True
    )
    model.load_state_dict(weights, strict=True)
    # Inference only: switch off dropout and similar train-time behavior.
    model.eval()
    logger.info("Loaded weights.")

    return dataset, model, device
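
# batch_step below rolls the model out autoregressively: step 1 feeds the
# observed input frames; every later step drops the oldest frame and appends
# the previous forecast, so the model always sees n_input_timestamps frames.
# A minimal sketch of the window update on a
# (batch, channels, time, height, width) tensor:
#
#   window = torch.cat((window[:, :, 1:, ...], forecast[:, :, None, ...]), dim=2)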
""" data_returned = [] forecast_hat = None # Initialize forecast_hat for step in range(1, hours_ahead + 1): if step == 1: curr_batch = { key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device) for key in ["ts", "time_delta_input"] } else: # Use the previous forecast_hat from the previous iteration if forecast_hat is not None: curr_batch["ts"] = torch.cat( (curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]), dim=2, ) forecast_hat = model(curr_batch) data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy() return data_returned def run_inference(init_time_idx, plt_channel_idx, hours_ahead): plt_channel_str = SDO_CHANNELS[plt_channel_idx] input_timestamp_1 = dataset.valid_indices[init_time_idx] input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h") output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h") input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M") input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M") output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M") sample_data, sample_metadata = dataset[init_time_idx] with torch.no_grad(): model_output = batch_step( model, sample_data, sample_metadata, device, hours_ahead ) means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs() vmin = float("-inf") vmax = float("inf") input_image = [] for i in range(2): input_image.append( inverse_transform_single_channel( sample_data["ts"][plt_channel_idx, i], mean=means[plt_channel_idx], std=stds[plt_channel_idx], epsilon=epsilons[plt_channel_idx], sl_scale_factor=sl_scale_factors[plt_channel_idx], ) ) vmin = max(vmin, input_image[i].min()) vmax = min(vmax, np.quantile(input_image[i], 0.99)) if plt_channel_str.startswith("aia"): cm_name = "sdo" + plt_channel_str else: cm_name = "hmimag" input_image = [ sunpy_cm.cmlist[cm_name]( (img[::-1]-vmin) / (vmax-vmin), bytes=True ) for img in input_image ] output_image = inverse_transform_single_channel( model_output[plt_channel_idx], mean=means[plt_channel_idx], std=stds[plt_channel_idx], epsilon=epsilons[plt_channel_idx], sl_scale_factor=sl_scale_factors[plt_channel_idx], ) output_image = sunpy_cm.cmlist[cm_name]( (output_image[::-1]-vmin) / (vmax-vmin), bytes=True ) return input_timestamp_0, input_image[0], input_timestamp_1, input_image[1], output_timestamp, output_image logging.basicConfig(level=logging.INFO) dataset, model, device = setup() with gr.Blocks() as demo: gr.Markdown(value="# Surya 1.0 - Visual forecasting demo") #with gr.Row(): #with gr.Column(): with gr.Row(): with gr.Column(): init_time = gr.Dropdown( [v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices], label="Initialization time", multiselect=False, type="index", ) with gr.Column(): plt_channel = gr.Dropdown( [c.upper() for c in SDO_CHANNELS], label="SDO Band", value="AIA94", multiselect=False, type="index" ) with gr.Row(): hours_ahead = gr.Slider(minimum=1.0, maximum=6.0, step=1.0, label="Forcast step [hours ahead]") with gr.Row(): btn = gr.Button("Run") with gr.Row(): with gr.Column(): input_timestamp_0 = gr.Textbox(label="Input 0") input_image_0 = gr.Image() with gr.Column(): input_timestamp_1 = gr.Textbox(label="Input 1") input_image_1 = gr.Image() with gr.Column(): output_timestamp = gr.Textbox(label="Prediction") output_image = gr.Image() btn.click( fn=run_inference, inputs=[init_time, plt_channel, hours_ahead], outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image] ) with gr.Row(): gr.Examples( examples=[ 
["2014-01-07 17:24", "AIA94", 2], ["2014-01-07 16:12", "AIA94", 6], ["2014-01-07 16:00", "AIA131", 1], ["2014-01-07 16:00", "HMI_M", 2], ], fn=run_inference, inputs=[init_time, plt_channel, hours_ahead], outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image], cache_examples=False, ) demo.launch()