Spaces:
Running
Running
import gradio as gr | |
import torch | |
from huggingface_hub import hf_hub_download | |
from surya.model import Surya | |
import numpy as np | |
from PIL import Image | |
import warnings | |
# Suppress warnings for a cleaner demo experience | |
warnings.filterwarnings("ignore") | |
# --- Model Loading ---
def load_model():
    """
    Download the pre-trained Surya checkpoint and return the model in eval mode.

    The weights file ``surya.366m.v1.pt`` is fetched from the
    ``nasa-ibm-ai4science/Surya-1.0`` Hugging Face repository, loaded onto the
    CPU, and applied to a freshly constructed ``Surya`` instance.

    This function is *not* memoized: every call builds a new model. The
    module calls it once at import time and keeps the result in the
    module-level ``model``; that single call is what limits loading to once.

    Returns:
        The ``Surya`` model with the checkpoint weights loaded, in eval mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt"
    )
    # These hyperparameters must describe the architecture the checkpoint was
    # trained with; otherwise load_state_dict is expected to fail on
    # mismatched keys/shapes.
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )
    # NOTE(review): torch.load unpickles the downloaded file. This assumes the
    # checkpoint is a bare state dict (not nested under a key such as "model")
    # -- confirm. If it holds only tensors, weights_only=True would avoid
    # executing arbitrary pickled objects.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    return model


# Load once at import time; every request handler reuses this instance.
model = load_model()
# --- Core Prediction Logic ---
def predict_solar_activity(time_steps, forecast_horizon):
    """
    Generate a simulated solar-activity forecast for the Gradio UI.

    A random tensor stands in for preprocessed SDO observations. A real
    deployment would instead fetch, align and normalize actual SDO data for
    the requested time steps, as described in the Surya paper.

    Args:
        time_steps: Number of input time steps. Slider values may arrive as
            floats, so the value is coerced to ``int`` before being used as a
            tensor dimension.
        forecast_horizon: Forecast horizon in hours. Not used yet: the model
            is run once on the input regardless of this value. The parameter
            is accepted so the UI wiring stays stable.

    Returns:
        A ``(PIL.Image.Image, dict)`` pair:
        - a grayscale rendering of channel 0 at the last time index of the
          model output;
        - a ``{label: confidence}`` mapping for ``gr.Label``, built from a
          random placeholder probability.
    """
    # torch.randn rejects non-integer sizes, so normalize the slider value.
    time_steps = int(time_steps)

    # Placeholder input laid out as [batch, channels, time, height, width].
    # NOTE(review): at 4096x4096 float32 this is ~832 MiB per time step
    # (~4 GiB at the default of 5); confirm the host has room for it.
    dummy_input = torch.randn(1, 13, time_steps, 4096, 4096)

    with torch.no_grad():
        prediction = model(dummy_input)
    # The input is no longer needed; release it before building the image.
    del dummy_input

    # --- Visualization ---
    # Assumes the output is [batch, channels, time, H, W] like the input --
    # TODO confirm against the Surya model's output contract.
    frame = prediction[0, 0, -1, :, :].float()
    low = frame.min()
    value_range = frame.max() - low
    if torch.isfinite(value_range) and value_range > 0:
        normalized = (frame - low) / value_range
    else:
        # A constant (or non-finite) frame would divide by zero / produce
        # NaNs, whose uint8 cast is undefined; render it as black instead.
        normalized = torch.zeros_like(frame)
    image_array = (normalized * 255).byte().cpu().numpy()
    predicted_image = Image.fromarray(image_array)

    # Placeholder flare probability until a real flare-forecast head is wired in.
    flare_probability = float(np.random.rand())
    if flare_probability > 0.5:
        flare_class = "M-class or X-class Flare"
        confidence = flare_probability
    else:
        flare_class = "No significant flare"
        confidence = 1 - flare_probability
    return predicted_image, {flare_class: confidence}
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header. The blank lines inside the <div> end the raw-HTML block so the
    # Markdown heading and emphasis inside it are rendered, not shown as text.
    gr.Markdown(
        """
    <div align="center">

    # ☀️ Surya: Foundation Model for Heliophysics ☀️
    *A Gradio Demo for NASA's Solar Foundation Model*

    </div>
        """
    )
    gr.Markdown(
        "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
        "This demo showcases its capability to forecast solar dynamics."
    )

    with gr.Row():
        # Left column: user-controlled prediction inputs.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model."
            )
            forecast_horizon_slider = gr.Slider(
                minimum=1, maximum=24, value=1,
                label="Forecast Horizon (hours)",
                info="How far into the future to predict."
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")

        # Right column: rendered forecast outputs.
        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
            gr.Markdown("### 🔥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")

    # Inputs are passed positionally as (time_steps, forecast_horizon); the
    # two returned values fill the image and label outputs in order.
    predict_button.click(
        fn=predict_solar_activity,
        inputs=[time_steps_slider, forecast_horizon_slider],
        outputs=[output_image, output_flare]
    )

    gr.Markdown("---")
    gr.Markdown(
        "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
        "The core of this application is the loaded Surya model from NASA and IBM."
    )
    gr.Markdown(
        "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )
def _main():
    """Start the Gradio web server for the ``demo`` app."""
    demo.launch()


# Only serve the UI when run as a script, not when imported.
if __name__ == "__main__":
    _main()