# surya-demo/app.py
import os
import warnings
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from surya.model import Surya  # resolved via the local surya/ package bundled with this repo

# Suppress warnings for a cleaner demo experience
warnings.filterwarnings("ignore")

# --- 1. Define Constants and Data Channels ---
# Based on the Surya project's data preprocessing
AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]
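# Sanity check: 8 AIA channels + 5 HMI channels = 13 inputs, matching the model's
# in_chans=13 below (e.g. ALL_CHANNELS[0] == "AIA 94 Å", ALL_CHANNELS[8] == "HMI bx").
assert len(ALL_CHANNELS) == 13
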
# --- 2. Caching and Loading the Model and Data ---
@lru_cache(maxsize=1)
def load_model_and_data():
    """
    Downloads the pre-trained Surya model and the test data, then initializes the model.
    The result is cached so this expensive work happens only once per process.
    """
    print("Downloading model and test data... This may take a moment.")

    # Define local directories for caching
    model_dir = "./surya_model"
    data_dir = "./surya_data"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Download the model weights and test data from Hugging Face
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
        local_dir=model_dir,
    )
    test_data_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="test_data.pt",
        local_dir=data_dir,
    )
    print("Downloads complete.")
    # Initialize the model architecture
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )
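    # These hyperparameters must match the downloaded checkpoint; the values above
    # correspond to the 366M-parameter surya.366m.v1 weights fetched earlier.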
    # Load the weights into the model
    print("Loading model weights...")
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully.")

    # Load the test data (map to CPU so this also works on hosts without a GPU)
    test_data = torch.load(test_data_path, map_location="cpu")
    test_input = test_data["input"]  # Input tensor for the model
    test_label = test_data["label"]  # Ground truth for comparison
    return model, test_input, test_label

# --- 3. Helper function for Image Conversion ---
def tensor_to_image(tensor_slice):
    """
    Normalizes a 2D tensor slice and converts it to a PIL Image for display.
    """
    # Detach tensor from graph, move to CPU, and convert to numpy
    img_np = tensor_slice.detach().cpu().numpy()

    # Normalize the tensor to a 0-255 range for image display
    min_val, max_val = np.min(img_np), np.max(img_np)
    if max_val > min_val:
        img_np = (img_np - min_val) / (max_val - min_val)
    else:
        img_np = np.zeros_like(img_np)  # constant slice: render as black instead of overflowing
    img_array = (img_np * 255).astype(np.uint8)
    return Image.fromarray(img_array)
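
# Illustrative usage (not part of the app flow): any 2-D float tensor becomes
# an 8-bit grayscale PIL image of the same dimensions, e.g.
#   tensor_to_image(torch.randn(512, 512)).size  ->  (512, 512)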
# --- 4. Main Prediction and Visualization Function ---
def run_forecast(channel_name, progress=gr.Progress()):
    """
    This function is triggered by the button click in the Gradio interface.
    It runs the model prediction and generates the images for display.
    """
    progress(0, desc="Loading model and data (first run may be slow)...")
    # Load the model and data (fast after the first run due to caching)
    model, test_input, test_label = load_model_and_data()

    progress(0.5, desc="Running inference on the model...")
    # Perform the forecast
    with torch.no_grad():
        prediction = model(test_input)

    progress(0.8, desc="Generating visualizations...")
    # Get the index of the selected channel
    channel_index = ALL_CHANNELS.index(channel_name)

    # Extract the last time step from the input sequence for display.
    # Shape: [batch, channels, time, height, width] -> selected channel, last time step
    input_slice = test_input[0, channel_index, -1, :, :]
    input_image = tensor_to_image(input_slice)

    # Extract the corresponding slice from the model's prediction.
    # Shape: [batch, channels, time, height, width] -> selected channel, first predicted step
    predicted_slice = prediction[0, channel_index, 0, :, :]
    predicted_image = tensor_to_image(predicted_slice)

    # Extract the corresponding slice from the ground-truth label
    label_slice = test_label[0, channel_index, 0, :, :]
    label_image = tensor_to_image(label_slice)

    print(f"Forecast generated for channel: {channel_name}")
    return input_image, predicted_image, label_image
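
# Optional helper (a hypothetical sketch, not wired into run_forecast above): the note
# in the UI below mentions that images are downscaled for display. One way to do that
# before returning them would be a LANCZOS resize, keeping the 4096x4096 frames cheap to ship.
def downscale_for_display(img: Image.Image, size: int = 512) -> Image.Image:
    """Resize a full-resolution solar image for lightweight browser display."""
    return img.resize((size, size), Image.LANCZOS)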
# --- 5. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        # ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️

        This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,
        allowing a direct comparison between the model's prediction and the real ground truth.

        </div>
        """
    )
    with gr.Row():
        channel_selector = gr.Dropdown(
            choices=ALL_CHANNELS,
            value=ALL_CHANNELS[2],  # Default to "AIA 171 Å"
            label="🛰️ Select SDO Instrument Channel",
            info="Choose which solar observation channel to visualize.",
        )
        run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⬅️ Final Input Image")
            gr.Markdown("The last image shown to the model before it makes a prediction.")
            input_display = gr.Image(label="Input Observation", height=400, width=400)
        with gr.Column():
            gr.Markdown("### 🔮 Model's Forecast")
            gr.Markdown("What the Surya model predicted the Sun would look like.")
            prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
        with gr.Column():
            gr.Markdown("### ✅ Ground Truth")
            gr.Markdown("What the Sun *actually* looked like at the forecast time.")
            label_display = gr.Image(label="Actual Observation", height=400, width=400)
    gr.Markdown(
        "---\n"
        "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
        "The images are downscaled for display in this demo. "
        "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )

    run_button.click(
        fn=run_forecast,
        inputs=[channel_selector],
        outputs=[input_display, prediction_display, label_display],
    )
if __name__ == "__main__":
    demo.launch(debug=True)