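"""Gradio demo for visual forecasting with the Surya-1.0 heliophysics foundation model.

The app downloads the pretrained checkpoint and a small slice of SDO validation data
from the Hugging Face Hub, rolls the model forward up to six hours, and shows the two
input frames next to the predicted frame for a selected SDO channel.
"""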
import yaml
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import gradio as gr
from huggingface_hub import snapshot_download
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, custom_collate_fn
logger = logging.getLogger(__name__)
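# The SDO channels the model works with: eight AIA bands followed by five HMI
# products (line-of-sight magnetogram, vector field components, and Dopplergram).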
SDO_CHANNELS = [
"aia94",
"aia131",
"aia171",
"aia193",
"aia211",
"aia304",
"aia335",
"aia1600",
"hmi_m",
"hmi_bx",
"hmi_by",
"hmi_bz",
"hmi_v",
]
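# Lightweight container for a single SDO image (channel name, pixel data, timestamp, and image type).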
@dataclass
class SDOImage:
channel: str
data: np.ndarray
timestamp: str
type: str
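# Download the Surya-1.0 artifacts (config, scalers, weights) and a small
# 2014-01-07 validation slice from the Hugging Face Hub into the local data/ folder.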
def download_data():
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0",
local_dir="data/Surya-1.0",
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
token=None,
)
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
repo_type="dataset",
local_dir="data/Surya-1.0_validation_data",
allow_patterns="20140107_1[5-9]??.nc",
token=None,
)
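# Build the validation dataset from the NetCDF index, using the data settings of the downloaded config.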
def get_dataset(config, scalers) -> HelioNetCDFDataset:
dataset = HelioNetCDFDataset(
index_path="tests/test_surya_index.csv",
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
rollout_steps=0,
channels=config["data"]["sdo_channels"],
drop_hmi_probability=config["data"]["drop_hmi_probability"],
num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
scalers=scalers,
phase="valid",
pooling=config["data"]["pooling"],
random_vert_flip=config["data"]["random_vert_flip"],
)
logger.info(f"Initialized the dataset. {len(dataset)} samples.")
return dataset
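# Load the per-channel scaler definitions used to normalize and de-normalize the SDO data.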
def get_scalers() -> dict:
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
scalers = build_scalers(info=scalers_info)
logger.info("Built the scalers.")
return scalers
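# Instantiate the HelioSpectFormer backbone with the hyperparameters from the downloaded config.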
def get_model_from_config(config) -> HelioSpectFormer:
model = HelioSpectFormer(
img_size=config["model"]["img_size"],
patch_size=config["model"]["patch_size"],
in_chans=len(config["data"]["sdo_channels"]),
embed_dim=config["model"]["embed_dim"],
time_embedding={
"type": "linear",
"time_dim": len(config["data"]["time_delta_input_minutes"]),
},
depth=config["model"]["depth"],
n_spectral_blocks=config["model"]["n_spectral_blocks"],
num_heads=config["model"]["num_heads"],
mlp_ratio=config["model"]["mlp_ratio"],
drop_rate=config["model"]["drop_rate"],
dtype=torch.bfloat16,
window_size=config["model"]["window_size"],
dp_rank=config["model"]["dp_rank"],
learned_flow=config["model"]["learned_flow"],
use_latitude_in_learned_flow=config["model"]["learned_flow"],
init_weights=False,
checkpoint_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
rpe=config["model"]["rpe"],
ensemble=config["model"]["ensemble"],
finetune=config["model"]["finetune"],
)
logger.info("Initialized the model.")
return model
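# Read the model and data configuration that ships with the checkpoint.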
def get_config() -> dict:
with open("data/Surya-1.0/config.yaml") as fp:
config = yaml.safe_load(fp)
return config
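# One-time start-up: download artifacts, build the dataset and the model, select a device, and load the pretrained weights.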
def setup():
logger.info("Loading data ...")
download_data()
config = get_config()
scalers = get_scalers()
logger.info("Initializing dataset ...")
dataset = get_dataset(config, scalers)
logger.info("Initializing model ...")
model = get_model_from_config(config)
if torch.cuda.is_available():
device = torch.cuda.current_device()
        logger.info(f"GPU detected. Running the demo on device {device}.")
else:
device = "cpu"
        logger.warning("No GPU detected. Running the demo on CPU.")
model.to(device)
n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")
path_weights = "data/Surya-1.0/surya.366m.v1.pt"
weights = torch.load(
path_weights, map_location=torch.device(device), weights_only=True
)
model.load_state_dict(weights, strict=True)
logger.info("Loaded weights.")
return dataset, model, device
def batch_step(
model: HelioSpectFormer,
sample_data: dict,
sample_metadata: dict,
device: int | str,
hours_ahead: int = 1,
) -> np.ndarray:
"""
    Run an autoregressive rollout for the given model, sample data, metadata, and device,
    feeding each prediction back in as the newest input frame.
Args:
model: The PyTorch model to use for prediction.
sample_data: A dictionary containing input and target data for the batch.
sample_metadata: A dictionary containing metadata for the batch, including timestamps.
device: The device to use for computation ('cpu', 'cuda' or device number).
        hours_ahead: Number of one-hour rollout steps to forecast ahead. Defaults to 1.
    Returns:
        np.ndarray: Prediction for the final rollout step, as a float32 array of shape (channels, height, width).
"""
    data_returned = []
    forecast_hat = None  # most recent prediction; fed back as input on later rollout steps
for step in range(1, hours_ahead + 1):
if step == 1:
curr_batch = {
key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
for key in ["ts", "time_delta_input"]
}
else:
# Use the previous forecast_hat from the previous iteration
if forecast_hat is not None:
curr_batch["ts"] = torch.cat(
(curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
dim=2,
)
forecast_hat = model(curr_batch)
data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
return data_returned
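# Gradio callback: run the forecast for the selected initialization time, channel, and lead
# time, then colorize the two input frames and the prediction with the matching SunPy colormap.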
def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
plt_channel_str = SDO_CHANNELS[plt_channel_idx]
input_timestamp_1 = dataset.valid_indices[init_time_idx]
input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h")
output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h")
input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M")
input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M")
output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M")
sample_data, sample_metadata = dataset[init_time_idx]
with torch.no_grad():
model_output = batch_step(
model,
sample_data,
sample_metadata,
device,
            int(hours_ahead),  # the slider may deliver a float; the rollout expects an integer step count
)
means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()
vmin = float("-inf")
vmax = float("inf")
input_image = []
for i in range(2):
input_image.append(
inverse_transform_single_channel(
sample_data["ts"][plt_channel_idx, i],
mean=means[plt_channel_idx],
std=stds[plt_channel_idx],
epsilon=epsilons[plt_channel_idx],
sl_scale_factor=sl_scale_factors[plt_channel_idx],
)
)
vmin = max(vmin, input_image[i].min())
vmax = min(vmax, np.quantile(input_image[i], 0.99))
if plt_channel_str.startswith("aia"):
cm_name = "sdo" + plt_channel_str
else:
cm_name = "hmimag"
input_image = [
sunpy_cm.cmlist[cm_name](
            (img[::-1] - vmin) / (vmax - vmin), bytes=True
)
for img in input_image
]
output_image = inverse_transform_single_channel(
model_output[plt_channel_idx],
mean=means[plt_channel_idx],
std=stds[plt_channel_idx],
epsilon=epsilons[plt_channel_idx],
sl_scale_factor=sl_scale_factors[plt_channel_idx],
)
output_image = sunpy_cm.cmlist[cm_name](
        (output_image[::-1] - vmin) / (vmax - vmin), bytes=True
)
return input_timestamp_0, input_image[0], input_timestamp_1, input_image[1], output_timestamp, output_image
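# Application start-up: configure logging, load the dataset and model once, and build the Gradio UI.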
logging.basicConfig(level=logging.INFO)
dataset, model, device = setup()
with gr.Blocks() as demo:
gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")
with gr.Row():
with gr.Column():
init_time = gr.Dropdown(
[v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
label="Initialization time",
multiselect=False,
type="index",
)
with gr.Column():
plt_channel = gr.Dropdown(
[c.upper() for c in SDO_CHANNELS],
label="SDO Band",
value="AIA94",
multiselect=False,
type="index"
)
with gr.Row():
        hours_ahead = gr.Slider(minimum=1.0, maximum=6.0, step=1.0, label="Forecast step [hours ahead]")
with gr.Row():
btn = gr.Button("Run")
with gr.Row():
with gr.Column():
input_timestamp_0 = gr.Textbox(label="Input 0")
input_image_0 = gr.Image()
with gr.Column():
input_timestamp_1 = gr.Textbox(label="Input 1")
input_image_1 = gr.Image()
with gr.Column():
output_timestamp = gr.Textbox(label="Prediction")
output_image = gr.Image()
btn.click(
fn=run_inference,
inputs=[init_time, plt_channel, hours_ahead],
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image]
)
with gr.Row():
gr.Examples(
examples=[
["2014-01-07 17:24", "AIA94", 2],
["2014-01-07 16:12", "AIA94", 6],
["2014-01-07 16:00", "AIA131", 1],
["2014-01-07 16:00", "HMI_M", 2],
],
fn=run_inference,
inputs=[init_time, plt_channel, hours_ahead],
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image],
cache_examples=False,
)
demo.launch()