# Save this file as app.py in the root of the cloned Surya repository
import glob
import logging
import os
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from huggingface_hub import snapshot_download
from PIL import Image
from torch.utils.data import DataLoader

# --- Use the official Surya modules now that we are in the repo ---
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, custom_collate_fn

# Silence warnings and keep logging at INFO for a cleaner UI
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Global cache to store expensive-to-load objects ---
APP_CACHE = {
    "model": None,
    "config": None,
    "scalers": None,
    "full_results": None,  # Will store all prediction results
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
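# Note: module-level state such as APP_CACHE is shared by every visitor of a
# single Gradio worker process. That is acceptable for this read-only demo;
# truly per-session state belongs in gr.State (used for the forecast results
# in the UI section below).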
# SDO channels from the test script for the dropdown menu
SDO_CHANNELS = [
    "aia94", "aia131", "aia171", "aia193", "aia211", "aia304", "aia335",
    "aia1600", "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
]
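# A channel's position in this list doubles as its index into the model's
# channel dimension (see generate_visualization), so the ordering must match
# config["data"]["sdo_channels"] in the downloaded config.yaml.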

# --- 1. Setup, Download, and Model Loading (adapting fixtures from test_surya.py) ---
def setup_and_load_model(progress=gr.Progress()):
    """
    Handles all initial setup: downloading data, loading configs, and initializing the model.
    This function populates the APP_CACHE.
    """
    if APP_CACHE["model"] is not None:
        logger.info("Model and data already loaded. Skipping setup.")
        return
    # --- Part A: Download data (from download_data fixture) ---
    progress(0.1, desc="Downloading model weights and config...")
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir="data/Surya-1.0",
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
    )
    progress(0.3, desc="Downloading validation data for 2014-01-07...")
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
        repo_type="dataset",
        local_dir="data/Surya-1.0_validation_data",
        allow_patterns="20140107_1[5-9]??.nc",
    )
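    # The pattern above restricts the download to files timestamped between
    # 15:00 and 19:59 UTC on 2014-01-07, keeping the demo's disk and bandwidth
    # footprint small.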
    # --- Part B: Load Config and Scalers (from config & scalers fixtures) ---
    progress(0.5, desc="Loading configuration and data scalers...")
    with open("data/Surya-1.0/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    with open("data/Surya-1.0/scalers.yaml") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)
    # --- Part C: Initialize and load model (from model fixture and test function) ---
    progress(0.7, desc="Initializing model architecture...")
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )
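    # All architecture hyperparameters above are read from the downloaded
    # config.yaml, so the module shapes line up with the published checkpoint
    # (surya.366m.v1.pt) loaded with strict=True below.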
    progress(0.8, desc=f"Loading model weights to {APP_CACHE['device']}...")
    path_weights = "data/Surya-1.0/surya.366m.v1.pt"
    weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
    model.load_state_dict(weights, strict=True)
    model.to(APP_CACHE["device"])
    model.eval()
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    logger.info(f"Surya FM: {n_params:.2f}M parameters loaded.")
    APP_CACHE["model"] = model

# --- 2. Inference Logic (adapting the test loop) ---
def run_full_forecast():
    """
    Runs inference on the entire validation dataset and stores results.
    """
    if APP_CACHE["full_results"] is not None:
        return APP_CACHE["full_results"]

    model = APP_CACHE["model"]
    config = APP_CACHE["config"]
    device = APP_CACHE["device"]

    # Create the index file needed by the dataset loader
    os.makedirs("tests", exist_ok=True)
    with open("tests/test_surya_index.csv", "w") as f:
        f.write("path\n")
        search_path = os.path.join("data/Surya-1.0_validation_data", "**", "*.nc")
        for nc_file in sorted(glob.glob(search_path, recursive=True)):
            f.write(f"{nc_file}\n")
    # Setup dataset and dataloader (from dataset & dataloader fixtures)
    dataset = HelioNetCDFDataset(
        index_path="tests/test_surya_index.csv",
        time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
        time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
        n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
        rollout_steps=1,
        channels=config["data"]["sdo_channels"],
        scalers=APP_CACHE["scalers"],
        phase="valid",
    )
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=2,
        pin_memory=True,
        collate_fn=custom_collate_fn,
    )
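    # batch_size=1 keeps peak GPU memory low and maps each batch onto exactly
    # one forecast timestep in the UI; shuffle=False preserves the
    # chronological order that the timestep slider assumes.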
    all_results = []
    with torch.no_grad():
        for batch_data, batch_metadata in dataloader:
            input_batch = {k: v.to(device) for k, v in batch_data.items() if k in ["ts", "time_delta_input"]}
            with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
                prediction = model(input_batch)
            # Store the relevant tensors on CPU for later visualization
            result = {
                "input": input_batch["ts"].cpu(),
                "prediction": prediction.cpu(),
                "target": batch_data["forecast"].cpu(),
                "input_timestamp": np.datetime_as_string(batch_metadata["timestamps_input"][0][-1], unit='s'),
                "target_timestamp": np.datetime_as_string(batch_metadata["timestamps_targets"][0][0], unit='s'),
            }
            all_results.append(result)

    APP_CACHE["full_results"] = all_results
    # Cache the inverse-transform statistics needed to denormalize for visualization
    APP_CACHE["scalers_vis"] = dataset.transformation_inputs()
    return all_results
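# Note: run_full_forecast keeps full-resolution input/prediction/target tensors
# for every timestep in host memory. That is cheap for the handful of
# validation timesteps used here; for longer date ranges, consider keeping only
# the displayed channel or spilling results to disk.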

# --- 3. Visualization Logic ---
def generate_visualization(results, timestep_index, channel_name):
    """
    Generates PIL images for a specific timestep and channel from the results.
    Returns (input image, prediction image, target image, status text).
    """
    if not results:
        return None, None, None, "No results available. Please run the forecast."

    timestep_data = results[timestep_index]
    c_idx = SDO_CHANNELS.index(channel_name)
    means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers_vis"]

    # Denormalize data for visualization
    input_slice = inverse_transform_single_channel(
        timestep_data["input"][0, c_idx, -1].numpy(),
        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx],
    )
    pred_slice = inverse_transform_single_channel(
        timestep_data["prediction"][0, c_idx].numpy(),
        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx],
    )
    target_slice = inverse_transform_single_channel(
        timestep_data["target"][0, c_idx, 0].numpy(),
        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx],
    )
    # Convert to PIL Images using appropriate colormaps. AIA intensities are
    # non-negative, while the HMI channels (magnetograms/velocity) are signed,
    # so the latter get a symmetric color range around zero.
    is_signed = channel_name.startswith("hmi")
    if is_signed:
        vmax = np.quantile(np.abs(target_slice), 0.995)
        vmin = -vmax
    else:
        vmax = np.quantile(target_slice, 0.995)
        vmin = 0.0
    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    def to_pil(data):
        # Guard against a degenerate (all-zero) color range
        data_norm = (np.clip(data, vmin, vmax) - vmin) / max(vmax - vmin, 1e-8)
        return Image.fromarray(
            (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        ).transpose(Image.Transpose.TRANSPOSE)

    status_text = (
        f"Displaying Timestep {timestep_index + 1}/{len(results)}\n"
        f"Input: {timestep_data['input_timestamp']} | Forecast/Target: {timestep_data['target_timestamp']}"
    )
    return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), status_text

# --- 4. Gradio Controller Functions ---
def forecast_controller(progress=gr.Progress(track_tqdm=True)):
    """Main function for the 'Generate Forecast' button."""
    progress(0, desc="Starting setup...")
    setup_and_load_model(progress)
    logger.info("Running forecast on all validation timesteps...")
    progress(0.9, desc="Running inference on validation data...")
    results = run_full_forecast()
    logger.info(f"Forecast complete. {len(results)} timesteps processed.")
    # Generate the first visualization (default to aia171)
    img_in, img_pred, img_target, status = generate_visualization(results, 0, SDO_CHANNELS[2])
    # Update the slider to be interactive with one step per forecast timestep
    slider_update = gr.Slider(
        minimum=1, maximum=len(results), step=1, value=1, interactive=True,
        label="Forecast Timestep",
    )
    return results, img_in, img_pred, img_target, status, slider_update
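# Note (assumes Gradio 4.x): returning a freshly constructed gr.Slider(...)
# from a callback, as above, updates the existing component's properties;
# on Gradio 3.x you would return gr.Slider.update(...) instead.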
def update_visualization_controller(results, timestep, channel):
    """Called when the timestep slider or channel dropdown changes."""
    # The slider is 1-based for display; results are 0-indexed. Cast
    # defensively, since Gradio may deliver the slider value as a float.
    return generate_visualization(results, int(timestep) - 1, channel)

# --- 5. Gradio UI Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # State object to hold the results of the full inference run
    state_results = gr.State()

    gr.Markdown(
        """
        <div align='center'>

        # ☀️ Surya: Live Model Demo ☀️
        ### An Interactive Interface for NASA's Heliophysics Foundation Model
        This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
        <br>
        **Instructions:** 1. Click 'Generate Forecast'. 2. Use the controls to explore the results.
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            run_button = gr.Button("🔮 1. Generate Full Forecast", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False, value="Ready.", lines=2)
            channel_selector = gr.Dropdown(
                choices=SDO_CHANNELS, value="aia171", label="🛰️ 2. Select SDO Channel"
            )
            timestep_slider = gr.Slider(
                minimum=1, maximum=8, step=1, value=1, interactive=False, label="Forecast Timestep"
            )
        with gr.Column(scale=3):
            with gr.Row():
                input_display = gr.Image(label="Last Input", height=512, width=512, interactive=False)
                prediction_display = gr.Image(label="Model Forecast", height=512, width=512, interactive=False)
                target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
    # --- Event Handlers ---
    run_button.click(
        fn=forecast_controller,
        outputs=[state_results, input_display, prediction_display, target_display, status_box, timestep_slider],
    )
    # When the user changes the channel or timestep, re-render the visualization
    channel_selector.change(
        fn=update_visualization_controller,
        inputs=[state_results, timestep_slider, channel_selector],
        outputs=[input_display, prediction_display, target_display, status_box],
    )
    timestep_slider.change(
        fn=update_visualization_controller,
        inputs=[state_results, timestep_slider, channel_selector],
        outputs=[input_display, prediction_display, target_display, status_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)