Spaces:
Running
Running
File size: 5,044 Bytes
f966038 5f00446 f966038 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import functools
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from surya.model import Surya
# Silence every warning category so the demo's console output stays uncluttered.
warnings.simplefilter("ignore")
# --- Model Loading ---
@functools.lru_cache(maxsize=None)
def load_model():
    """
    Download the pre-trained Surya weights and build the model in eval mode.

    The result is memoised with ``functools.lru_cache``, so every call after
    the first returns the same model instance instead of re-downloading the
    checkpoint and re-allocating the network.

    Returns:
        The ``Surya`` model with the checkpoint's state dict loaded (weights
        mapped to CPU) and switched to evaluation mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
    )
    # NOTE(review): these hyper-parameters must match the architecture the
    # checkpoint was saved from; load_state_dict (strict by default) raises
    # on missing/unexpected keys or shape mismatches.
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )
    # torch.load unpickles the file; only load checkpoints from trusted repos.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Loaded once at import time so the first user request does not pay for it.
model = load_model()
# --- Core Prediction Logic ---
def _tensor_to_image(tensor):
    """Min-max scale a 2-D tensor to 0-255 and return it as a grayscale PIL image."""
    tensor = tensor.detach().float().cpu()
    lo, hi = tensor.min(), tensor.max()
    span = hi - lo
    if span > 0:
        scaled = (tensor - lo) / span
    else:
        # A constant (or NaN) frame would divide by zero and yield NaNs,
        # whose uint8 cast is undefined; render it as a black image instead.
        scaled = torch.zeros_like(tensor)
    return Image.fromarray((scaled * 255).to(torch.uint8).numpy())


def predict_solar_activity(time_steps, forecast_horizon):
    """
    Generate a simulated solar-activity forecast with the Surya model.

    This demo feeds the model a random input tensor instead of real SDO
    observations, and the flare output is a random placeholder rather than a
    model prediction. A real deployment would fetch, align and normalise SDO
    data for the requested time steps (see the Surya paper).

    Args:
        time_steps: Number of input time steps. Coerced to ``int`` because
            Gradio sliders may deliver numeric values as floats.
        forecast_horizon: Forecast horizon in hours. Currently unused: the
            model is called the same way regardless of this value.

    Returns:
        A ``(PIL.Image.Image, dict)`` pair: a grayscale rendering of one slice
        of the model output, and a ``{label: confidence}`` mapping suitable
        for ``gr.Label``.

    Raises:
        ValueError: If ``time_steps`` is less than 1.
    """
    time_steps = int(time_steps)
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")

    # Random stand-in for a sequence of solar observations.
    # NOTE(review): layout assumed to be (batch, channels, time, height, width)
    # with 13 channels at 4096x4096 to match in_chans/img_size used when the
    # model is built; confirm against the Surya input specification.
    # Each time step alone is ~872 MB of float32.
    dummy_input = torch.randn(1, 13, time_steps, 4096, 4096)

    with torch.no_grad():
        prediction = model(dummy_input)
    # Release the large input before building the image to lower peak memory.
    del dummy_input

    # NOTE(review): indexing assumes a 5-D output laid out like the input;
    # this renders channel 0 at the last time index of the first batch item.
    predicted_image = _tensor_to_image(prediction[0, 0, -1, :, :])

    # Placeholder flare "prediction": a uniform random draw, not model output.
    flare_probability = float(np.random.rand())
    if flare_probability > 0.5:
        flare_class = "M-class or X-class Flare"
        confidence = flare_probability
    else:
        flare_class = "No significant flare"
        confidence = 1 - flare_probability
    return predicted_image, {flare_class: confidence}
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The blank lines inside the <div> are required: Markdown treats text
    # immediately following an HTML opening tag as raw HTML, so without them
    # the "#" heading and "*...*" emphasis would be displayed literally.
    gr.Markdown(
        '<div align="center">\n\n'
        "# ☀️ Surya: Foundation Model for Heliophysics ☀️\n"
        "*A Gradio Demo for NASA's Solar Foundation Model*\n\n"
        "</div>"
    )
    gr.Markdown(
        "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
        "This demo showcases its capability to forecast solar dynamics."
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model."
            )
            # step=1 keeps the horizon in whole hours; without it Gradio
            # derives a fractional step from the slider's range.
            forecast_horizon_slider = gr.Slider(
                minimum=1, maximum=24, value=1, step=1,
                label="Forecast Horizon (hours)",
                info="How far into the future to predict."
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
            gr.Markdown("### 💥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")
    predict_button.click(
        fn=predict_solar_activity,
        inputs=[time_steps_slider, forecast_horizon_slider],
        outputs=[output_image, output_flare]
    )
    gr.Markdown("---")
    gr.Markdown(
        "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
        "The core of this application is the loaded Surya model from NASA and IBM."
    )
    gr.Markdown(
        "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )

if __name__ == "__main__":
    demo.launch()