import functools
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from surya.model import Surya

# Suppress warnings for a cleaner demo experience.
# NOTE(review): this silences *all* warnings process-wide (torch, gradio, ...),
# which can hide useful diagnostics; narrow the filter if debugging is needed.
warnings.filterwarnings("ignore")

# Model geometry, shared by the model constructor and the dummy input below so
# the two cannot drift apart.
IMG_SIZE = 4096
IN_CHANS = 13


# --- Model Loading ---
@functools.lru_cache(maxsize=1)
def load_model():
    """Download the pre-trained Surya weights and build the model in eval mode.

    The result is memoized, so repeated calls return the same model instance
    and the checkpoint is only downloaded and loaded once per process.

    Returns:
        The ``Surya`` model with the checkpoint's state dict loaded, on CPU.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
    )
    model = Surya(
        img_size=IMG_SIZE,
        patch_size=16,
        in_chans=IN_CHANS,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )
    # weights_only=True restricts unpickling to tensors and plain containers, so
    # a downloaded checkpoint cannot execute arbitrary code. It is already the
    # default from PyTorch 2.6; passing it explicitly protects older versions.
    # NOTE(review): assumes the file is a bare state dict -- confirm it is not
    # wrapped (e.g. {"model": ...}) in the published checkpoint.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()


# --- Core Prediction Logic ---
def _tensor_to_image(tensor):
    """Min-max scale a 2-D tensor to 0-255 and return it as a grayscale PIL image."""
    # float32 on CPU: numpy cannot represent e.g. bfloat16, and PIL needs host memory.
    tensor = tensor.detach().float().cpu()
    low, high = tensor.min(), tensor.max()
    value_range = high - low
    if value_range > 0:
        normalized = (tensor - low) / value_range
    else:
        # A constant (or NaN-containing) frame would divide by zero / propagate
        # NaN; render it as a black image instead of garbage bytes.
        normalized = torch.zeros_like(tensor)
    # Truncating uint8 conversion, matching Tensor.byte().
    image_array = (normalized * 255).to(torch.uint8).numpy()
    return Image.fromarray(image_array)


def _dummy_flare_prediction():
    """Return a random placeholder ``{label: confidence}`` mapping for ``gr.Label``."""
    flare_probability = float(np.random.rand())
    if flare_probability > 0.5:
        return {"M-class or X-class Flare": flare_probability}
    return {"No significant flare": 1.0 - flare_probability}


def predict_solar_activity(time_steps, forecast_horizon):
    """Generate a (simulated) Surya forecast for the demo UI.

    A random input tensor stands in for real SDO observations. In a real
    deployment this function would fetch the SDO data for the requested time
    steps and preprocess it (alignment and normalization as described in the
    Surya paper) before running the model.

    Args:
        time_steps: Number of input observations. This is a slider value and
            may arrive as a float such as ``5.0``.
        forecast_horizon: Requested lead time in hours. Currently unused: the
            model is called once on the dummy input regardless of this value.

    Returns:
        A ``(PIL.Image, dict)`` pair. The image is a grayscale rendering of
        output channel 0 at the last index of dim 2. The dict is a placeholder
        flare label/confidence mapping for ``gr.Label``.
    """
    # torch size arguments must be ints; Gradio sliders can deliver floats.
    time_steps = int(time_steps)

    # Dummy sequence of solar observations laid out as
    # [batch, channels, time_steps, height, width].
    # NOTE(review): assumes Surya's forward() accepts this layout -- confirm
    # against the model implementation.
    # With the default float32 dtype this is ~4.4 GB at time_steps=5.
    dummy_input = torch.randn(1, IN_CHANS, time_steps, IMG_SIZE, IMG_SIZE)

    with torch.no_grad():
        prediction = model(dummy_input)
    # Release the multi-GB input before building the output image.
    del dummy_input

    # --- Visualization ---
    # Channel 0 at the last index of dim 2 (assumed to be the last predicted
    # time step -- TODO confirm the model's output layout).
    predicted_image = _tensor_to_image(prediction[0, 0, -1, :, :])

    # Flare prediction is a random placeholder, not a model output.
    return predicted_image, _dummy_flare_prediction()


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ☀️ Surya: Foundation Model for Heliophysics ☀️\n"
        "*A Gradio Demo for NASA's Solar Foundation Model*"
    )
    gr.Markdown(
        "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
        "This demo showcases its capability to forecast solar dynamics."
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model.",
            )
            forecast_horizon_slider = gr.Slider(
                minimum=1,
                maximum=24,
                value=1,
                step=1,
                label="Forecast Horizon (hours)",
                info="How far into the future to predict.",
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            # NOTE(review): the image shows output channel 0; confirm that
            # channel is AIA 171 Å as this label claims.
            output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
            gr.Markdown("### 💥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")

    predict_button.click(
        fn=predict_solar_activity,
        inputs=[time_steps_slider, forecast_horizon_slider],
        outputs=[output_image, output_flare],
    )

    gr.Markdown("---")
    gr.Markdown(
        "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
        "The core of this application is the loaded Surya model from NASA and IBM."
    )
    gr.Markdown(
        "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )


if __name__ == "__main__":
    demo.launch()