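"""Gradio demo for Surya-1.0, NASA-IBM's heliophysics foundation model.

Downloads the pretrained weights and a small slice of SDO validation data from
the Hugging Face Hub, runs an autoregressive forecast up to six hours ahead,
and displays the two input frames next to the predicted frame.
"""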
import logging
from dataclasses import dataclass

import gradio as gr
import numpy as np
import pandas as pd
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from huggingface_hub import snapshot_download

from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers

logger = logging.getLogger(__name__)
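# The 13 SDO input channels: eight AIA bands (EUV/UV wavelengths in angstroms)
# and five HMI products (line-of-sight magnetogram, vector magnetic field
# components, and Dopplergram velocity).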
SDO_CHANNELS = [
    "aia94",
    "aia131",
    "aia171",
    "aia193",
    "aia211",
    "aia304",
    "aia335",
    "aia1600",
    "hmi_m",
    "hmi_bx",
    "hmi_by",
    "hmi_bz",
    "hmi_v",
]
@dataclass
class SDOImage:
    channel: str
    data: np.ndarray
    timestamp: str
    type: str
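# Fetch the pretrained checkpoint plus its config and scalers, and a small
# validation slice; the glob below matches the 2014-01-07 15:00-19:59 UTC
# NetCDF files.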
def download_data():
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir="data/Surya-1.0",
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
        token=None,
    )
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
        repo_type="dataset",
        local_dir="data/Surya-1.0_validation_data",
        allow_patterns="20140107_1[5-9]??.nc",
        token=None,
    )
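# The index CSV maps timestamps to the downloaded NetCDF files. rollout_steps=0
# requests a single target frame; phase="valid" is assumed to disable the
# training-time augmentations configured below (HMI dropout, AIA masking,
# vertical flips).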
def get_dataset(config, scalers) -> HelioNetCDFDataset:
    dataset = HelioNetCDFDataset(
        index_path="tests/test_surya_index.csv",
        time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
        time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
        n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
        rollout_steps=0,
        channels=config["data"]["sdo_channels"],
        drop_hmi_probability=config["data"]["drop_hmi_probability"],
        num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
        use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
        scalers=scalers,
        phase="valid",
        pooling=config["data"]["pooling"],
        random_vert_flip=config["data"]["random_vert_flip"],
    )
    logger.info(f"Initialized the dataset. {len(dataset)} samples.")
    return dataset
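# Per-channel normalization statistics (mean, std, epsilon, signum-log scale
# factor) shipped with the checkpoint; the same quantities feed the inverse
# transform used for plotting in run_inference.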
def get_scalers() -> dict:
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    scalers = build_scalers(info=scalers_info)
    logger.info("Built the scalers.")
    return scalers
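# HelioSpectFormer: judging by the constructor arguments below, a
# SpectFormer-style vision transformer whose first n_spectral_blocks layers do
# spectral (FFT) token mixing while the remaining layers use attention;
# bfloat16 keeps the ~366M-parameter backbone memory-friendly.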
def get_model_from_config(config) -> HelioSpectFormer:
    model = HelioSpectFormer(
        img_size=config["model"]["img_size"],
        patch_size=config["model"]["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=config["model"]["embed_dim"],
        time_embedding={
            "type": "linear",
            "time_dim": len(config["data"]["time_delta_input_minutes"]),
        },
        depth=config["model"]["depth"],
        n_spectral_blocks=config["model"]["n_spectral_blocks"],
        num_heads=config["model"]["num_heads"],
        mlp_ratio=config["model"]["mlp_ratio"],
        drop_rate=config["model"]["drop_rate"],
        dtype=torch.bfloat16,
        window_size=config["model"]["window_size"],
        dp_rank=config["model"]["dp_rank"],
        learned_flow=config["model"]["learned_flow"],
        use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(10)),
        rpe=config["model"]["rpe"],
        ensemble=config["model"]["ensemble"],
        finetune=config["model"]["finetune"],
    )
    logger.info("Initialized the model.")
    return model
def get_config() -> dict:
    with open("data/Surya-1.0/config.yaml") as fp:
        config = yaml.safe_load(fp)
    return config
def setup():
    logger.info("Loading data ...")
    download_data()
    config = get_config()
    scalers = get_scalers()

    logger.info("Initializing dataset ...")
    dataset = get_dataset(config, scalers)

    logger.info("Initializing model ...")
    model = get_model_from_config(config)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        logger.info(f"GPU detected. Running the test on device {device}.")
    else:
        device = "cpu"
        logger.warning("No GPU detected. Running the test on CPU.")
    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
    logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")

    path_weights = "data/Surya-1.0/surya.366m.v1.pt"
    # weights_only=True restricts torch.load to tensor data, so the checkpoint
    # cannot execute arbitrary code on load.
    weights = torch.load(
        path_weights, map_location=torch.device(device), weights_only=True
    )
    model.load_state_dict(weights, strict=True)
    model.eval()  # inference only: disable dropout
    logger.info("Loaded weights.")
    return dataset, model, device
def batch_step(
    model: HelioSpectFormer,
    sample_data: dict,
    sample_metadata: dict,
    device: int | str,
    hours_ahead: int = 1,
) -> np.ndarray:
    """
    Perform an autoregressive rollout for the given model, batch data, metadata, and device.

    Args:
        model: The PyTorch model to use for prediction.
        sample_data: A dictionary containing input and target data for the batch.
        sample_metadata: A dictionary containing metadata for the batch, including timestamps.
        device: The device to use for computation ("cpu", "cuda", or a device number).
        hours_ahead: The number of one-hour rollout steps to forecast ahead. Defaults to 1.

    Returns:
        np.ndarray: The forecast for the final step, shaped (channels, height, width).
    """
    forecast_hat = None
    for step in range(1, hours_ahead + 1):
        if step == 1:
            # Build the initial batch from the dataset sample (batch dim added).
            curr_batch = {
                key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
                for key in ["ts", "time_delta_input"]
            }
        elif forecast_hat is not None:
            # Slide the input window: drop the oldest frame and append the
            # previous step's forecast along the time axis (dim 2).
            curr_batch["ts"] = torch.cat(
                (curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
                dim=2,
            )
        forecast_hat = model(curr_batch)
    return forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
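# Example usage (hypothetical index and horizon), mirroring run_inference below:
#     sample_data, sample_metadata = dataset[0]
#     with torch.no_grad():
#         pred = batch_step(model, sample_data, sample_metadata, device, hours_ahead=2)
#     pred.shape  # (len(SDO_CHANNELS), height, width)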
def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
    plt_channel_str = SDO_CHANNELS[plt_channel_idx]

    # The sample at init_time_idx provides two input frames one hour apart.
    input_timestamp_1 = dataset.valid_indices[init_time_idx]
    input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h")
    output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h")
    input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M")
    input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M")
    output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M")

    sample_data, sample_metadata = dataset[init_time_idx]
    with torch.no_grad():
        model_output = batch_step(
            model,
            sample_data,
            sample_metadata,
            device,
            hours_ahead,
        )

    means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()

    # Shared color scale across both input frames: the larger of the two minima
    # and the smaller of the two 99th percentiles.
    vmin = float("-inf")
    vmax = float("inf")
    input_image = []
    for i in range(2):
        input_image.append(
            inverse_transform_single_channel(
                sample_data["ts"][plt_channel_idx, i],
                mean=means[plt_channel_idx],
                std=stds[plt_channel_idx],
                epsilon=epsilons[plt_channel_idx],
                sl_scale_factor=sl_scale_factors[plt_channel_idx],
            )
        )
        vmin = max(vmin, input_image[i].min())
        vmax = min(vmax, np.quantile(input_image[i], 0.99))

    # SunPy ships per-band AIA colormaps ("sdoaia94", ...); HMI products share
    # "hmimag".
    if plt_channel_str.startswith("aia"):
        cm_name = "sdo" + plt_channel_str
    else:
        cm_name = "hmimag"
    input_image = [
        sunpy_cm.cmlist[cm_name]((img[::-1] - vmin) / (vmax - vmin), bytes=True)
        for img in input_image
    ]

    output_image = inverse_transform_single_channel(
        model_output[plt_channel_idx],
        mean=means[plt_channel_idx],
        std=stds[plt_channel_idx],
        epsilon=epsilons[plt_channel_idx],
        sl_scale_factor=sl_scale_factors[plt_channel_idx],
    )
    output_image = sunpy_cm.cmlist[cm_name](
        (output_image[::-1] - vmin) / (vmax - vmin), bytes=True
    )

    return (
        input_timestamp_0,
        input_image[0],
        input_timestamp_1,
        input_image[1],
        output_timestamp,
        output_image,
    )
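# The six return values above map one-to-one onto the `outputs` lists wired up
# in the Gradio UI below.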
logging.basicConfig(level=logging.INFO)
dataset, model, device = setup()
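# Gradio UI: dropdowns for initialization time and SDO band, a slider for the
# forecast horizon, and three columns showing the two inputs and the prediction.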
with gr.Blocks() as demo:
    gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")
    with gr.Row():
        with gr.Column():
            init_time = gr.Dropdown(
                [v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
                label="Initialization time",
                multiselect=False,
                type="index",
            )
        with gr.Column():
            plt_channel = gr.Dropdown(
                [c.upper() for c in SDO_CHANNELS],
                label="SDO Band",
                value="AIA94",
                multiselect=False,
                type="index",
            )
    with gr.Row():
        hours_ahead = gr.Slider(
            minimum=1.0, maximum=6.0, step=1.0, label="Forecast step [hours ahead]"
        )
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        with gr.Column():
            input_timestamp_0 = gr.Textbox(label="Input 0")
            input_image_0 = gr.Image()
        with gr.Column():
            input_timestamp_1 = gr.Textbox(label="Input 1")
            input_image_1 = gr.Image()
        with gr.Column():
            output_timestamp = gr.Textbox(label="Prediction")
            output_image = gr.Image()
    btn.click(
        fn=run_inference,
        inputs=[init_time, plt_channel, hours_ahead],
        outputs=[
            input_timestamp_0,
            input_image_0,
            input_timestamp_1,
            input_image_1,
            output_timestamp,
            output_image,
        ],
    )
    with gr.Row():
        gr.Examples(
            examples=[
                ["2014-01-07 17:24", "AIA94", 2],
                ["2014-01-07 16:12", "AIA94", 6],
                ["2014-01-07 16:00", "AIA131", 1],
                ["2014-01-07 16:00", "HMI_M", 2],
            ],
            fn=run_inference,
            inputs=[init_time, plt_channel, hours_ahead],
            outputs=[
                input_timestamp_0,
                input_image_0,
                input_timestamp_1,
                input_image_1,
                output_timestamp,
                output_image,
            ],
            cache_examples=False,
        )

demo.launch()