Spaces:
Running
Running
File size: 6,705 Bytes
f966038 5f00446 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b f966038 f43722b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import functools
import os
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from surya.model import Surya # This now works because of the file structure
# Silence every warning category so the demo console output stays uncluttered.
warnings.filterwarnings("ignore")

# --- 1. Define Constants and Data Channels ---
# Channel name lists. The position of a label in ALL_CHANNELS is used directly as
# the channel-axis index into the model tensors (8 AIA + 5 HMI = 13 entries).
# NOTE(review): the original comment says these follow the Surya project's data
# preprocessing; the HMI entries ("by_abs", "bz_abs") should be confirmed against
# the upstream Surya channel list before the labels are trusted.
AIA_CHANNELS = "94 131 171 193 211 304 335 1600".split()
HMI_CHANNELS = "bx by bz by_abs bz_abs".split()

# Human-readable dropdown labels: AIA wavelengths first, then HMI products.
ALL_CHANNELS = [
    *(f"AIA {wavelength} Å" for wavelength in AIA_CHANNELS),
    *(f"HMI {product}" for product in HMI_CHANNELS),
]
# --- 2. Caching and Loading the Model and Data ---
# functools.lru_cache (stdlib) memoizes independently of the installed Gradio
# version; the function takes no arguments, so a single cache slot is enough.
@functools.lru_cache(maxsize=1)
def load_model_and_data():
    """
    Download the pre-trained Surya weights and the test sample, and build the model.

    The result is memoized, so the downloads and model construction run only on the
    first call in this process; later calls return the same objects.

    Returns:
        (model, test_input, test_label): the ``Surya`` model in eval mode with the
        checkpoint loaded on CPU, plus the ``"input"`` and ``"label"`` entries of
        ``test_data.pt``, also mapped onto CPU.
    """
    print("Downloading model and test data... This may take a moment.")

    # Local directories the files are downloaded into.
    model_dir = "./surya_model"
    data_dir = "./surya_data"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Fetch the model weights and the test data from the Hugging Face Hub.
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
        local_dir=model_dir,
    )
    test_data_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="test_data.pt",
        local_dir=data_dir,
    )
    print("Downloads complete.")

    # Build the architecture. load_state_dict below is strict, so these
    # hyper-parameters must match the checkpoint; in_chans=13 equals the number of
    # entries in ALL_CHANNELS.
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    print("Loading model weights...")
    # The file is passed straight to load_state_dict, so it should be a plain
    # tensor mapping. weights_only=True refuses arbitrary pickled objects.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully.")

    # map_location="cpu": the model runs on CPU, and tensors saved from a GPU would
    # otherwise fail to deserialize on CPU-only hardware.
    test_data = torch.load(test_data_path, map_location="cpu")
    test_input = test_data["input"]  # Input tensor for the model
    test_label = test_data["label"]  # Ground truth for comparison
    return model, test_input, test_label
# --- 3. Helper function for Image Conversion ---
def tensor_to_image(tensor_slice):
    """
    Convert a 2-D tensor slice into an 8-bit grayscale PIL image for display.

    Values are min-max scaled to [0, 255] using only the finite entries of the
    slice. NaN/inf entries are drawn as 0 (black). A constant slice, which has no
    range to stretch, becomes an all-black image. Previously its raw values were
    cast straight to uint8 and could wrap around.

    Args:
        tensor_slice: torch tensor holding a single 2-D image plane.

    Returns:
        PIL.Image.Image built from a uint8 array of the same height/width.
    """
    # Detach from the graph and move to CPU. Cast to float32 before .numpy(),
    # because NumPy has no bfloat16 and .numpy() raises on such tensors.
    img_np = tensor_slice.detach().cpu().float().numpy()

    finite = np.isfinite(img_np)
    scaled = np.zeros(img_np.shape, dtype=np.float32)
    if finite.any():
        min_val = img_np[finite].min()
        max_val = img_np[finite].max()
        if max_val > min_val:
            # Non-finite pixels are replaced with 0 so they cannot poison the cast.
            scaled = np.where(finite, (img_np - min_val) / (max_val - min_val), 0.0)

    # Clip guards against any float round-off outside [0, 1] before the uint8 cast.
    img_array = (np.clip(scaled, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(img_array)
# --- 4. Main Prediction and Visualization Function ---
def run_forecast(channel_name, progress=gr.Progress()):
    """
    Run the Surya forecast on the bundled test sample and render one channel.

    Triggered by the button click in the Gradio interface; Gradio injects
    ``progress``.

    Args:
        channel_name: Dropdown selection; must be one of ``ALL_CHANNELS``.
        progress: Gradio progress tracker used to report loading/inference stages.

    Returns:
        (input_image, predicted_image, label_image): PIL images of the selected
        channel for the last input step, the first predicted step, and the
        matching ground-truth step.

    Raises:
        gr.Error: if ``channel_name`` is empty or not a known channel.
    """
    # Validate before the expensive load/inference. A cleared dropdown may send
    # None, and ALL_CHANNELS.index() would then fail late with a bare ValueError.
    if channel_name not in ALL_CHANNELS:
        raise gr.Error(
            f"Unknown channel {channel_name!r}. Please select a channel from the dropdown."
        )
    channel_index = ALL_CHANNELS.index(channel_name)

    progress(0, desc="Loading model and data (first run may be slow)...")
    # load_model_and_data is memoized, so only the first call downloads anything.
    model, test_input, test_label = load_model_and_data()

    progress(0.5, desc="Running inference on the model...")
    with torch.no_grad():
        prediction = model(test_input)

    progress(0.8, desc="Generating visualizations...")
    # All three tensors are indexed as 5-D. The layout is assumed to be
    # [batch, channels, time, height, width]. TODO: confirm against Surya's I/O spec.
    # Input: selected channel at the last step of the input sequence.
    input_image = tensor_to_image(test_input[0, channel_index, -1, :, :])
    # Prediction / ground truth: selected channel at the first forecast step.
    predicted_image = tensor_to_image(prediction[0, channel_index, 0, :, :])
    label_image = tensor_to_image(test_label[0, channel_index, 0, :, :])

    print(f"Forecast generated for channel: {channel_name}")
    return input_image, predicted_image, label_image
# --- 5. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Centered title and short description of what the demo shows.
    header_markdown = (
        "\n"
        '<div align="center">\n'
        "# ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️\n"
        "This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,\n"
        "allowing a direct comparison between the model's prediction and the real ground truth.\n"
        "</div>\n"
    )
    gr.Markdown(header_markdown)

    # Controls: channel picker and the button that triggers run_forecast.
    with gr.Row():
        channel_selector = gr.Dropdown(
            label="🛰️ Select SDO Instrument Channel",
            info="Choose which solar observation channel to visualize.",
            choices=ALL_CHANNELS,
            value=ALL_CHANNELS[2],  # Default to "AIA 171 Å"
        )
        run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")

    # One (heading, caption, image label) triple per output column, left to right.
    panel_specs = (
        (
            "### ⬅️ Final Input Image",
            "The last image shown to the model before it makes a prediction.",
            "Input Observation",
        ),
        (
            "### 🔮 Model's Forecast",
            "What the Surya model predicted the Sun would look like.",
            "Surya Prediction",
        ),
        (
            "### ✅ Ground Truth",
            "What the Sun *actually* looked like at the forecast time.",
            "Actual Observation",
        ),
    )
    panel_images = []
    with gr.Row():
        for heading, caption, image_label in panel_specs:
            with gr.Column():
                gr.Markdown(heading)
                gr.Markdown(caption)
                panel_images.append(gr.Image(label=image_label, height=400, width=400))
    input_display, prediction_display, label_display = panel_images

    # Footer with download/performance notes and a link to the model card.
    footer_markdown = (
        "--- \n"
        "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
        "The images are downscaled for display in this demo. "
        "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )
    gr.Markdown(footer_markdown)

    # Wire the button: dropdown value in, three images out (same order as returned).
    run_button.click(
        fn=run_forecast,
        inputs=[channel_selector],
        outputs=[input_display, prediction_display, label_display],
    )
# Start the web UI only when this file is executed as a script, not when imported.
# (The stray trailing "|" that followed this call was a syntax error and is removed.)
if __name__ == "__main__":
    demo.launch(debug=True)