# Save this file as app.py in the root of the cloned Surya repository
import gradio as gr
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from huggingface_hub import snapshot_download
import yaml
import numpy as np
from PIL import Image
import sunpy.visualization.colormaps as sunpy_cm
import os
import glob
import warnings
import logging
import matplotlib.pyplot as plt
# --- Use the official Surya modules now that we are in the repo ---
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, custom_collate_fn
# Suppress warnings for a cleaner UI; keep INFO-level logging for progress messages
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Global cache to store expensive-to-load objects ---
APP_CACHE = {
"model": None,
"config": None,
"scalers": None,
"full_results": None, # Will store all prediction results
"device": "cuda" if torch.cuda.is_available() else "cpu",
}
# SDO channels from the test script for the dropdown menu
SDO_CHANNELS = [
"aia94", "aia131", "aia171", "aia193", "aia211", "aia304", "aia335",
"aia1600", "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
]
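# NOTE: this ordering is assumed to match config["data"]["sdo_channels"], since
# SDO_CHANNELS.index() is used below to index the channel axis of the model tensors.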
# --- 1. Setup, Download, and Model Loading (adapting fixtures from test_surya.py) ---
def setup_and_load_model(progress=gr.Progress()):
"""
Handles all initial setup: downloading data, loading configs, and initializing the model.
This function will populate the APP_CACHE.
"""
if APP_CACHE["model"] is not None:
logger.info("Model and data already loaded. Skipping setup.")
return
# --- Part A: Download data (from download_data fixture) ---
progress(0.1, desc="Downloading model weights and config...")
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0",
local_dir="data/Surya-1.0",
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
)
progress(0.3, desc="Downloading validation data for 2014-01-07...")
snapshot_download(
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
repo_type="dataset",
local_dir="data/Surya-1.0_validation_data",
allow_patterns="20140107_1[5-9]??.nc",
)
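    # The allow_patterns glob above selects only the 2014-01-07 NetCDF files
    # timestamped between 15:00 and 19:59 UTC.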
# --- Part B: Load Config and Scalers (from config & scalers fixtures) ---
progress(0.5, desc="Loading configuration and data scalers...")
with open("data/Surya-1.0/config.yaml") as fp:
config = yaml.safe_load(fp)
APP_CACHE["config"] = config
    with open("data/Surya-1.0/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)
# --- Part C: Initialize and load model (from model fixture and test function) ---
progress(0.7, desc="Initializing model architecture...")
model_config = config["model"]
model = HelioSpectFormer(
img_size=model_config["img_size"], patch_size=model_config["patch_size"],
in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
)
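    # These arguments mirror the model fixture in test_surya.py; checkpoint_layers
    # presumably enables per-block activation checkpointing, which has little effect
    # under torch.no_grad() inference but keeps the configuration identical to the tests.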
progress(0.8, desc=f"Loading model weights to {APP_CACHE['device']}...")
path_weights = "data/Surya-1.0/surya.366m.v1.pt"
weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
model.load_state_dict(weights, strict=True)
model.to(APP_CACHE["device"])
model.eval()
n_params = sum(p.numel() for p in model.parameters()) / 1e6
logger.info(f"Surya FM: {n_params:.2f}M parameters loaded.")
APP_CACHE["model"] = model
# --- 2. Inference Logic (adapting the test loop) ---
def run_full_forecast():
"""
Runs inference on the entire validation dataset and stores results.
"""
if APP_CACHE["full_results"] is not None:
return APP_CACHE["full_results"]
model = APP_CACHE["model"]
config = APP_CACHE["config"]
device = APP_CACHE["device"]
# Create the index file needed by the dataset loader
os.makedirs("tests", exist_ok=True)
with open("tests/test_surya_index.csv", "w") as f:
f.write("path\n")
search_path = os.path.join("data/Surya-1.0_validation_data", "**", "*.nc")
for nc_file in sorted(glob.glob(search_path, recursive=True)):
f.write(f"{nc_file}\n")
# Setup dataset and dataloader (from dataset & dataloader fixtures)
dataset = HelioNetCDFDataset(
index_path="tests/test_surya_index.csv",
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
rollout_steps=1, channels=config["data"]["sdo_channels"],
scalers=APP_CACHE["scalers"], phase="valid",
)
dataloader = DataLoader(
dataset, shuffle=False, batch_size=1, num_workers=2,
pin_memory=True, collate_fn=custom_collate_fn
)
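    # batch_size=1 means each batch is a single forecast window, so every entry
    # appended to all_results below corresponds to one position of the UI timestep slider.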
all_results = []
with torch.no_grad():
for batch_data, batch_metadata in dataloader:
input_batch = {k: v.to(device) for k, v in batch_data.items() if k in ["ts", "time_delta_input"]}
with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
prediction = model(input_batch)
# Store the relevant tensors on CPU for later visualization
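            # Tensor shapes (inferred from how generate_visualization indexes them below):
            #   "ts"         -> (batch, channel, input_time, H, W)
            #   "prediction" -> (batch, channel, H, W)
            #   "forecast"   -> (batch, channel, rollout_step, H, W)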
result = {
"input": input_batch["ts"].cpu(),
"prediction": prediction.cpu(),
"target": batch_data["forecast"].cpu(),
"input_timestamp": np.datetime_as_string(batch_metadata["timestamps_input"][0][-1], unit='s'),
"target_timestamp": np.datetime_as_string(batch_metadata["timestamps_targets"][0][0], unit='s'),
}
all_results.append(result)
APP_CACHE["full_results"] = all_results
# Cache scalers needed for visualization
APP_CACHE["scalers_vis"] = dataset.transformation_inputs()
return all_results
# --- 3. Visualization Logic ---
def generate_visualization(results, timestep_index, channel_name):
"""
Generates PIL images for a specific timestep and channel from the results.
"""
    if not results:
        return None, None, None, "No results available. Please run the forecast."
timestep_data = results[timestep_index]
c_idx = SDO_CHANNELS.index(channel_name)
means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers_vis"]
# Denormalize data for visualization
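    # inverse_transform_single_channel undoes the per-channel normalization
    # (mean/std standardization plus the signed-log scaling implied by sl_scale_factor),
    # so the displayed images are in approximately physical units.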
input_slice = inverse_transform_single_channel(
timestep_data["input"][0, c_idx, -1].numpy(),
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
)
pred_slice = inverse_transform_single_channel(
timestep_data["prediction"][0, c_idx].numpy(),
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
)
target_slice = inverse_transform_single_channel(
timestep_data["target"][0, c_idx, 0].numpy(),
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
)
# Convert to PIL Images using appropriate colormaps
vmax = np.quantile(target_slice, 0.995)
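    # Note: stretching to [0, vmax] suits the non-negative AIA intensity channels;
    # the HMI channels contain signed values, so this simple scaling is only approximate for them.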
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
def to_pil(data):
data_clipped = np.clip(data, 0, vmax)
data_norm = data_clipped / vmax
return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)).transpose(Image.Transpose.TRANSPOSE)
status_text = (f"Displaying Timestep {timestep_index+1}/{len(results)}\n"
f"Input: {timestep_data['input_timestamp']} | Forecast/Target: {timestep_data['target_timestamp']}")
return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), status_text
# --- 4. Gradio Controller Functions ---
def forecast_controller(progress=gr.Progress(track_tqdm=True)):
"""Main function for the 'Generate Forecast' button."""
progress(0, desc="Starting setup...")
setup_and_load_model(progress)
logger.info("Running forecast on all validation timesteps...")
progress(0.9, desc="Running inference on validation data...")
results = run_full_forecast()
logger.info(f"Forecast complete. {len(results)} timesteps processed.")
# Generate the first visualization
img_in, img_pred, img_target, status = generate_visualization(results, 0, SDO_CHANNELS[2]) # Default to aia171
# Update the slider to be interactive and have the correct number of steps
slider_update = gr.Slider(minimum=1, maximum=len(results), step=1, value=1, interactive=True,
label="Forecast Timestep")
return results, img_in, img_pred, img_target, status, slider_update
def update_visualization_controller(results, timestep, channel):
"""Called when a slider or dropdown is changed."""
return generate_visualization(results, timestep - 1, channel)
# --- 5. Gradio UI Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
# State object to hold the results of the full inference run
state_results = gr.State()
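    # gr.State holds the forecast results per browser session, while APP_CACHE keeps
    # them process-wide so repeated runs skip the expensive setup and inference.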
gr.Markdown(
"""
<div align='center'>
# ☀️ Surya: Live Model Demo ☀️
### An Interactive Interface for NASA's Heliophysics Foundation Model
This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
<br>
        **Instructions:** 1. Click 'Generate Full Forecast'. 2. Use the controls to explore the results.
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
run_button = gr.Button("🔮 1. Generate Full Forecast", variant="primary")
status_box = gr.Textbox(label="Status", interactive=False, value="Ready.", lines=2)
channel_selector = gr.Dropdown(
choices=SDO_CHANNELS, value="aia171", label="🛰️ 2. Select SDO Channel"
)
timestep_slider = gr.Slider(
minimum=1, maximum=8, step=1, value=1, interactive=False, label="Forecast Timestep"
)
with gr.Column(scale=3):
with gr.Row():
input_display = gr.Image(label="Last Input", height=512, width=512, interactive=False)
prediction_display = gr.Image(label="Model Forecast", height=512, width=512, interactive=False)
target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
# --- Event Handlers ---
run_button.click(
fn=forecast_controller,
outputs=[state_results, input_display, prediction_display, target_display, status_box, timestep_slider]
)
# When the user changes the channel or timestep, call the visualization update function
channel_selector.change(
fn=update_visualization_controller,
inputs=[state_results, timestep_slider, channel_selector],
outputs=[input_display, prediction_display, target_display, status_box]
)
timestep_slider.change(
fn=update_visualization_controller,
inputs=[state_results, timestep_slider, channel_selector],
outputs=[input_display, prediction_display, target_display, status_box]
)
if __name__ == "__main__":
demo.launch(debug=True)