import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from surya.model import Surya  # Surya model definition, vendored alongside this app
import numpy as np
from PIL import Image
import functools
import os
import warnings

# Suppress warnings for a cleaner demo experience
warnings.filterwarnings("ignore")

# --- 1. Define Constants and Data Channels ---
# Based on the Surya project's data preprocessing
AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]
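# 8 AIA EUV/UV channels plus 5 HMI magnetic-field channels give the 13 input
# channels the model is constructed with below (in_chans=13).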

# --- 2. Caching and Loading the Model and Data ---
@functools.lru_cache(maxsize=1)
def load_model_and_data():
    """
    Downloads the pre-trained Surya model and the test data, then initializes
    the model. The result is cached, so the download and setup run only once.
    """
print("Downloading model and test data... This may take a moment.") | |
# Define local directories for caching | |
model_dir = "./surya_model" | |
data_dir = "./surya_data" | |
os.makedirs(model_dir, exist_ok=True) | |
os.makedirs(data_dir, exist_ok=True) | |
# Download the model weights and test data from Hugging Face | |
checkpoint_path = hf_hub_download( | |
repo_id="nasa-ibm-ai4science/Surya-1.0", | |
filename="surya.366m.v1.pt", | |
local_dir=model_dir | |
) | |
test_data_path = hf_hub_download( | |
repo_id="nasa-ibm-ai4science/Surya-1.0", | |
filename="test_data.pt", | |
local_dir=data_dir | |
) | |
print("Downloads complete.") | |

    # Initialize the model architecture
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )
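    # Note: img_size=4096 means the model operates on full-resolution SDO
    # frames, so inference is memory-intensive; a GPU-backed Space is
    # strongly preferable to CPU hardware here.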

    # Load the weights into the model
    print("Loading model weights...")
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully.")

    # Load the test data (input sequence and ground-truth target)
    test_data = torch.load(test_data_path, map_location="cpu")
    test_input = test_data["input"]  # Input tensor for the model
    test_label = test_data["label"]  # Ground truth for comparison
    return model, test_input, test_label
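
# Quick sanity check (illustrative only; assumes the [batch, channels, time,
# height, width] layout that run_forecast indexes below):
#   model, x, y = load_model_and_data()
#   assert x.shape[1] == len(ALL_CHANNELS)  # 13 input channels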

# --- 3. Helper Function for Image Conversion ---
def tensor_to_image(tensor_slice, max_size=512):
    """
    Normalizes a 2D tensor slice, converts it to a PIL Image, and downscales
    it for display.
    """
    # Detach the tensor from the graph, move it to the CPU, and convert to numpy
    img_np = tensor_slice.detach().cpu().numpy()

    # Min-max normalize to the 0-255 range for image display;
    # a constant slice maps to all zeros instead of overflowing uint8
    min_val, max_val = np.min(img_np), np.max(img_np)
    if max_val > min_val:
        img_np = (img_np - min_val) / (max_val - min_val)
    else:
        img_np = np.zeros_like(img_np)
    img_array = (img_np * 255).astype(np.uint8)

    # Downscale: the raw 4096x4096 frames are too large to send to the browser
    img = Image.fromarray(img_array)
    img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    return img
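
# Display note: plain min-max scaling is the simplest choice; AIA images are
# often shown with a log or asinh stretch instead (e.g. np.arcsinh(img_np)
# before normalizing), which can reveal fainter coronal structure.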

# --- 4. Main Prediction and Visualization Function ---
def run_forecast(channel_name, progress=gr.Progress()):
    """
    Triggered by the button click in the Gradio interface. Runs the model
    prediction and generates the images for display.
    """
    progress(0, desc="Loading model and data (first run may be slow)...")
    # Load the model and data (fast after the first run thanks to caching)
    model, test_input, test_label = load_model_and_data()

    progress(0.5, desc="Running inference on the model...")
    # Perform the forecast
    with torch.no_grad():
        prediction = model(test_input)

    progress(0.8, desc="Generating visualizations...")
    # Get the index of the selected channel
    channel_index = ALL_CHANNELS.index(channel_name)

    # Extract the last time step of the input sequence for display
    # Shape: [batch, channels, time, height, width] -> select channel, last time step
    input_slice = test_input[0, channel_index, -1, :, :]
    input_image = tensor_to_image(input_slice)

    # Extract the corresponding slice from the model's prediction
    # Shape: [batch, channels, time, height, width] -> select channel, first predicted step
    predicted_slice = prediction[0, channel_index, 0, :, :]
    predicted_image = tensor_to_image(predicted_slice)

    # Extract the corresponding slice from the ground-truth label
    label_slice = test_label[0, channel_index, 0, :, :]
    label_image = tensor_to_image(label_slice)

    print(f"Forecast generated for channel: {channel_name}")
    return input_image, predicted_image, label_image
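
# Note: the forward pass above does not depend on the selected channel, so the
# full prediction could itself be cached (e.g. computed once alongside
# load_model_and_data) to make switching channels nearly instant.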

# --- 5. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        # ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️

        This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,
        allowing a direct comparison between the model's prediction and the real ground truth.

        </div>
        """
    )
    with gr.Row():
        channel_selector = gr.Dropdown(
            choices=ALL_CHANNELS,
            value=ALL_CHANNELS[2],  # Default to "AIA 171 Å"
            label="🛰️ Select SDO Instrument Channel",
            info="Choose which solar observation channel to visualize.",
        )
        run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⬅️ Final Input Image")
            gr.Markdown("The last image shown to the model before it makes a prediction.")
            input_display = gr.Image(label="Input Observation", height=400, width=400)
        with gr.Column():
            gr.Markdown("### 🔮 Model's Forecast")
            gr.Markdown("What the Surya model predicted the Sun would look like.")
            prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
        with gr.Column():
            gr.Markdown("### ✅ Ground Truth")
            gr.Markdown("What the Sun *actually* looked like at the forecast time.")
            label_display = gr.Image(label="Actual Observation", height=400, width=400)
    gr.Markdown(
        "--- \n"
        "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
        "The images are downscaled for display in this demo. "
        "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )
    run_button.click(
        fn=run_forecast,
        inputs=[channel_selector],
        outputs=[input_display, prediction_display, label_display],
    )


if __name__ == "__main__":
    demo.launch(debug=True)
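
# To run locally: `python app.py`, then open the URL Gradio prints.
# debug=True surfaces server errors in the console while developing.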