import yaml
import logging
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import gradio as gr
from huggingface_hub import snapshot_download
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, custom_collate_fn
logger = logging.getLogger(__name__)
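
# The 13 SDO channels: eight AIA EUV/UV bands (94-1600 Å) followed by five HMI products
# (line-of-sight magnetogram, the Bx/By/Bz vector-field components, and the Dopplergram
# velocity). The list is indexed alongside the model output below, so it is expected to
# match the channel order in config["data"]["sdo_channels"].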
SDO_CHANNELS = [
"aia94",
"aia131",
"aia171",
"aia193",
"aia211",
"aia304",
"aia335",
"aia1600",
"hmi_m",
"hmi_bx",
"hmi_by",
"hmi_bz",
"hmi_v",
]
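
# Simple container for a single-channel SDO image; not referenced elsewhere in this
# script, presumably kept as a convenience for extensions.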
@dataclass
class SDOImage:
channel: str
data: np.ndarray
timestamp: str
type: str
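
# Fetch the model artifacts (config, scalers, and the 366M-parameter checkpoint) plus a
# small window of validation data (2014-01-07, hours 15-19) from the Hugging Face Hub.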
def download_data():
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0",
local_dir="data/Surya-1.0",
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
token=None,
)
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
repo_type="dataset",
local_dir="data/Surya-1.0_validation_data",
allow_patterns="20140107_1[5-9]??.nc",
token=None,
)
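
# Build the validation dataset. Most settings come from the checkpoint's config.yaml;
# rollout_steps=0 since the autoregressive rollout is handled manually in batch_step below.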
def get_dataset(config, scalers) -> HelioNetCDFDataset:
dataset = HelioNetCDFDataset(
index_path="tests/test_surya_index.csv",
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
rollout_steps=0,
channels=config["data"]["sdo_channels"],
drop_hmi_probability=config["data"]["drop_hmi_probability"],
num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
scalers=scalers,
phase="valid",
pooling=config["data"]["pooling"],
random_vert_flip=config["data"]["random_vert_flip"],
)
logger.info(f"Initialized the dataset. {len(dataset)} samples.")
return dataset
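
# Load per-channel normalization statistics and build the scaler objects consumed by
# HelioNetCDFDataset.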
def get_scalers() -> dict:
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
scalers = build_scalers(info=scalers_info)
logger.info("Built the scalers.")
return scalers
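
# Instantiate the HelioSpectFormer backbone using hyperparameters from config.yaml.
# init_weights=False since the pretrained checkpoint is loaded later in setup().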
def get_model_from_config(config) -> HelioSpectFormer:
model = HelioSpectFormer(
img_size=config["model"]["img_size"],
patch_size=config["model"]["patch_size"],
in_chans=len(config["data"]["sdo_channels"]),
embed_dim=config["model"]["embed_dim"],
time_embedding={
"type": "linear",
"time_dim": len(config["data"]["time_delta_input_minutes"]),
},
depth=config["model"]["depth"],
n_spectral_blocks=config["model"]["n_spectral_blocks"],
num_heads=config["model"]["num_heads"],
mlp_ratio=config["model"]["mlp_ratio"],
drop_rate=config["model"]["drop_rate"],
dtype=torch.bfloat16,
window_size=config["model"]["window_size"],
dp_rank=config["model"]["dp_rank"],
learned_flow=config["model"]["learned_flow"],
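        # Latitude conditioning for the learned flow follows the learned_flow flag itself.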
use_latitude_in_learned_flow=config["model"]["learned_flow"],
init_weights=False,
checkpoint_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
rpe=config["model"]["rpe"],
ensemble=config["model"]["ensemble"],
finetune=config["model"]["finetune"],
)
logger.info("Initialized the model.")
return model
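
# Read the model/data configuration shipped with the checkpoint.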
def get_config() -> dict:
with open("data/Surya-1.0/config.yaml") as fp:
config = yaml.safe_load(fp)
return config
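
# One-time setup: download artifacts, build the dataset and model, select a device,
# and load the pretrained weights.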
def setup():
logger.info("Loading data ...")
download_data()
config = get_config()
scalers = get_scalers()
logger.info("Initializing dataset ...")
dataset = get_dataset(config, scalers)
logger.info("Initializing model ...")
model = get_model_from_config(config)
if torch.cuda.is_available():
device = torch.cuda.current_device()
logger.info(f"GPU detected. Running the test on device {device}.")
else:
device = "cpu"
        logger.warning("No GPU detected. Running the test on CPU.")
model.to(device)
n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")
path_weights = "data/Surya-1.0/surya.366m.v1.pt"
weights = torch.load(
path_weights, map_location=torch.device(device), weights_only=True
)
model.load_state_dict(weights, strict=True)
logger.info("Loaded weights.")
return dataset, model, device
def batch_step(
model: HelioSpectFormer,
sample_data: dict,
sample_metadata: dict,
device: int | str,
hours_ahead: int = 1,
) -> np.ndarray:
"""
Perform a single batch step for the given model, batch data, metadata, and device.
Args:
model: The PyTorch model to use for prediction.
sample_data: A dictionary containing input and target data for the batch.
sample_metadata: A dictionary containing metadata for the batch, including timestamps.
device: The device to use for computation ('cpu', 'cuda' or device number).
hours_ahead: The number of steps to forecast ahead. Defaults to 1.
Returns:
np.ndarray: Output data.
"""
data_returned = []
    forecast_hat = None  # forecast from the previous rollout step; None before the first step
for step in range(1, hours_ahead + 1):
if step == 1:
curr_batch = {
key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
for key in ["ts", "time_delta_input"]
}
else:
            # Slide the input window: drop the oldest input frame and append the previous
            # step's forecast as the newest one.
if forecast_hat is not None:
curr_batch["ts"] = torch.cat(
(curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
dim=2,
)
forecast_hat = model(curr_batch)
data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
return data_returned
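
# Example (reference only, kept as a comment so it is never executed): a two-hour
# forecast for the first validation sample, outside the Gradio callback. Assumes
# `dataset`, `model`, and `device` have been created by setup().
#
#     sample_data, sample_metadata = dataset[0]
#     with torch.no_grad():
#         forecast = batch_step(model, sample_data, sample_metadata, device, hours_ahead=2)
#     # `forecast` is a float32 array in the model's normalized space; individual channels
#     # are recovered with inverse_transform_single_channel, as done below.

# Gradio callback: run the rollout for the selected initialization time and channel,
# then inverse-transform and colorize the two input frames and the forecast frame.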
def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
plt_channel_str = SDO_CHANNELS[plt_channel_idx]
input_timestamp_1 = dataset.valid_indices[init_time_idx]
input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h")
output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h")
input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M")
input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M")
output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M")
sample_data, sample_metadata = dataset[init_time_idx]
with torch.no_grad():
model_output = batch_step(
model,
sample_data,
sample_metadata,
device,
hours_ahead
)
means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()
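    # Common color scale for the two input frames: the largest per-frame minimum and the
    # smallest per-frame maximum of the normalized input data.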
vmin = float("-inf")
vmax = float("inf")
input_image = []
for i in range(2):
input_image.append(
inverse_transform_single_channel(
sample_data["ts"][plt_channel_idx, i],
mean=means[plt_channel_idx],
std=stds[plt_channel_idx],
epsilon=epsilons[plt_channel_idx],
sl_scale_factor=sl_scale_factors[plt_channel_idx],
)
)
vmin = max(vmin, sample_data["ts"][plt_channel_idx, i].min())
#vmax = min(vmax, np.quantile(sample_data["ts"][plt_channel_idx, i], 0.99))
vmax = min(vmax, sample_data["ts"][plt_channel_idx, i].max())
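    # AIA channels use their matching SDO colormaps; all HMI products fall back to "hmimag".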
if plt_channel_str.startswith("aia"):
cm_name = "sdo" + plt_channel_str
else:
cm_name = "hmimag"
input_image = [
sunpy_cm.cmlist[cm_name](
(img[::-1]-vmin) / (vmax-vmin), bytes=True
)
for img in input_image
]
output_image = inverse_transform_single_channel(
model_output[plt_channel_idx],
mean=means[plt_channel_idx],
std=stds[plt_channel_idx],
epsilon=epsilons[plt_channel_idx],
sl_scale_factor=sl_scale_factors[plt_channel_idx],
)
output_image = sunpy_cm.cmlist[cm_name](
(output_image[::-1]-vmin) / (vmax-vmin), bytes=True
)
    return (
        input_timestamp_0,
        input_image[0],
        input_timestamp_1,
        input_image[1],
        output_timestamp,
        output_image,
    )
logging.basicConfig(level=logging.INFO)
dataset, model, device = setup()
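
# Gradio UI: dropdowns for initialization time and SDO band, a slider for the forecast
# horizon, and three panels showing the two input frames and the model's forecast.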
with gr.Blocks() as demo:
gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")
with gr.Row():
with gr.Column():
init_time = gr.Dropdown(
[v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
label="Initialization time",
multiselect=False,
type="index",
)
with gr.Column():
plt_channel = gr.Dropdown(
[c.upper() for c in SDO_CHANNELS],
label="SDO Band",
value="AIA94",
multiselect=False,
type="index"
)
with gr.Row():
        hours_ahead = gr.Slider(
            minimum=1.0, maximum=6.0, step=1.0, label="Forecast step [hours ahead]"
        )
with gr.Row():
btn = gr.Button("Run")
with gr.Row():
with gr.Column():
input_timestamp_0 = gr.Textbox(label="Input 0")
input_image_0 = gr.Image()
with gr.Column():
input_timestamp_1 = gr.Textbox(label="Input 1")
input_image_1 = gr.Image()
with gr.Column():
output_timestamp = gr.Textbox(label="Prediction")
output_image = gr.Image()
btn.click(
fn=run_inference,
inputs=[init_time, plt_channel, hours_ahead],
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image]
)
with gr.Row():
gr.Examples(
examples=[
["2014-01-07 17:24", "AIA94", 2],
["2014-01-07 16:12", "AIA94", 6],
["2014-01-07 16:00", "AIA131", 1],
["2014-01-07 16:00", "HMI_M", 2],
],
fn=run_inference,
inputs=[init_time, plt_channel, hours_ahead],
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image],
cache_examples=False,
)
demo.launch()