# surya-demo / app.py: Hugging Face Space by broadfield-dev.
# Source revision: commit f966038 ("Update app.py", verified), 5.04 kB.
import functools
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from surya.model import Surya
# Demo-only: silence every warning category (the filter matches the base
# ``Warning`` class) so the Space's console log stays readable.
warnings.simplefilter("ignore")
# --- Model Loading ---
@functools.lru_cache(maxsize=1)
def load_model():
    """Download the pre-trained Surya checkpoint and return the model in eval mode.

    The file ``surya.366m.v1.pt`` is fetched from the
    ``nasa-ibm-ai4science/Surya-1.0`` Hugging Face repository via
    ``hf_hub_download`` and loaded onto the CPU. Its state dict is then copied
    into a freshly constructed ``Surya`` instance.

    The function is memoised with ``functools.lru_cache`` (the original
    ``@gr.cache`` is not a documented Gradio decorator). Repeated calls
    therefore return the same model object instead of rebuilding it.

    Returns:
        The ``Surya`` model with the checkpoint weights loaded, in ``eval()`` mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
    )

    # NOTE(review): these hyper-parameters must match the architecture the
    # checkpoint was trained with, otherwise load_state_dict() (strict by
    # default) raises on missing/unexpected keys or shape mismatches. They are
    # not verified here; confirm against the Surya model card / config.
    surya_model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted repository. If the file is a
    # plain state dict, consider passing weights_only=True (PyTorch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    surya_model.load_state_dict(state_dict)
    surya_model.eval()  # inference mode for dropout / norm layers
    return surya_model


# Loaded once at import time so the first request does not pay the download cost.
model = load_model()
# --- Core Prediction Logic ---
def _to_uint8_image(frame):
    """Min-max scale a 2-D tensor to the 0-255 range and return it as a PIL image.

    A constant frame (max == min) or one containing NaN/inf would make the
    division produce NaNs, whose ``.byte()`` conversion is undefined. Such
    frames are rendered as an all-black image instead.
    """
    frame = frame.detach().float()
    lo, hi = frame.min(), frame.max()
    value_range = hi - lo
    if value_range > 0:
        scaled = (frame - lo) / value_range
    else:
        scaled = torch.zeros_like(frame)
    return Image.fromarray((scaled * 255).byte().cpu().numpy())


def _simulate_flare_label():
    """Return a placeholder flare prediction as a ``{label: confidence}`` dict.

    This is NOT a model output: the probability is drawn uniformly at random
    with ``np.random.rand()`` purely to populate the demo's ``gr.Label``.
    """
    flare_probability = np.random.rand()
    if flare_probability > 0.5:
        return {"M-class or X-class Flare": flare_probability}
    return {"No significant flare": 1 - flare_probability}


def predict_solar_activity(time_steps, forecast_horizon):
    """Run the Surya model on placeholder input and build the demo outputs.

    No real observations are used. The model input is random noise, and the
    flare label is simulated. A real application would replace the dummy
    tensor with aligned, normalised SDO data (as described in the Surya paper).

    Args:
        time_steps: Number of input time steps, from the UI slider. Gradio
            sliders can deliver this as a float, so it is cast to ``int``
            (``torch.randn`` rejects float sizes).
        forecast_horizon: Forecast horizon in hours. Currently unused, because
            the model call below takes no horizon argument.

    Returns:
        A ``(PIL.Image.Image, dict)`` tuple containing:
        - a greyscale rendering of one output frame;
        - a ``{label: confidence}`` mapping for ``gr.Label``.

    Raises:
        ValueError: If ``time_steps`` is smaller than 1.
    """
    time_steps = int(time_steps)
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")

    # Dummy input laid out as [batch, channels, time_steps, height, width].
    # The 13 channels and 4096x4096 size mirror the model built in load_model().
    # NOTE(review): float32 costs 13*4096*4096*4 bytes, about 0.87 GB per time
    # step, so the upper slider values may exhaust memory on small hosts.
    dummy_input = torch.randn(1, 13, time_steps, 4096, 4096)

    with torch.no_grad():
        # forecast_horizon is not forwarded; the model sees only the input.
        prediction = model(dummy_input)

    # NOTE(review): assumes the output is also [batch, channel, time, H, W];
    # confirm against the Surya model's forward(). Shows channel 0 at the
    # last time step.
    predicted_image = _to_uint8_image(prediction[0, 0, -1, :, :])

    return predicted_image, _simulate_flare_label()
# --- Gradio Interface Definition ---
# The page is built inside a gr.Blocks context. The event listener must also be
# registered inside that context.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header markup notes:
    # - Under CommonMark, an HTML block opened by <div> runs until the next
    #   blank line. Without the blank lines below, the "#" heading would be
    #   emitted as raw text instead of a heading.
    # - The string is built without source indentation so no line can be
    #   mistaken for an indented code block.
    gr.Markdown(
        '<div align="center">\n'
        "\n"
        "# ☀️ Surya: Foundation Model for Heliophysics ☀️\n"
        "*A Gradio Demo for NASA's Solar Foundation Model*\n"
        "\n"
        "</div>"
    )
    gr.Markdown(
        "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
        "This demo showcases its capability to forecast solar dynamics."
    )

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model.",
            )
            # step=1: the label promises whole hours; without a step the
            # slider allows fractional values.
            forecast_horizon_slider = gr.Slider(
                minimum=1, maximum=24, value=1, step=1,
                label="Forecast Horizon (hours)",
                info="How far into the future to predict.",
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")

        # Right column: outputs filled by predict_solar_activity().
        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            # NOTE(review): the displayed frame is output channel 0; confirm
            # that channel really corresponds to AIA 171 Å.
            output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
            gr.Markdown("### 💥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")

    # Outputs map positionally to the (image, label-dict) tuple returned by
    # predict_solar_activity.
    predict_button.click(
        fn=predict_solar_activity,
        inputs=[time_steps_slider, forecast_horizon_slider],
        outputs=[output_image, output_flare],
    )

    gr.Markdown("---")
    gr.Markdown(
        "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
        "The core of this application is the loaded Surya model from NASA and IBM."
    )
    gr.Markdown(
        "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )

if __name__ == "__main__":
    demo.launch()