File size: 5,044 Bytes
f966038
5f00446
 
f966038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import functools
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from surya.model import Surya

# Silence every warning category so the demo console stays uncluttered.
warnings.simplefilter("ignore")

# --- Model Loading ---
# functools.lru_cache gives process-wide memoization from the standard
# library; no framework-specific caching decorator is required.
@functools.lru_cache(maxsize=None)
def load_model():
    """
    Download the pre-trained Surya checkpoint and return the model in eval mode.

    Fetches ``surya.366m.v1.pt`` from the ``nasa-ibm-ai4science/Surya-1.0``
    Hugging Face repository and builds a ``Surya`` instance with the
    hyper-parameters below. It then loads the checkpoint tensors onto the CPU
    and switches the model to inference mode with ``eval()``.

    The function is memoized, so repeated calls in the same process return
    the same model object instead of downloading and rebuilding it.

    Returns:
        The ``Surya`` model with the checkpoint weights loaded, in eval mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
    )

    # These hyper-parameters presumably have to match the ones the checkpoint
    # was trained with -- TODO confirm against the model card.
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    # The checkpoint comes from the network. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot
    # run arbitrary code on load. This assumes the file is a plain state dict,
    # which load_state_dict expects -- TODO confirm if loading ever fails here.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Built once at import time; predict_solar_activity reads this module-level name.
model = load_model()

# --- Core Prediction Logic ---
def _tensor_to_image(frame):
    """
    Min-max scale a 2-D tensor to the 0-255 range and return an 8-bit PIL image.

    A constant frame has max == min, and a frame containing NaN/inf has a
    non-finite range. Dividing by that range would yield NaNs that turn into
    arbitrary byte values, so an all-black image is returned instead.
    """
    frame = frame.detach().float().cpu()
    lowest = frame.min()
    spread = frame.max() - lowest
    if torch.isfinite(spread) and spread > 0:
        scaled = (frame - lowest) / spread
    else:
        scaled = torch.zeros_like(frame)
    # Truncating float -> uint8 cast, same as Tensor.byte().
    pixels = (scaled * 255).to(torch.uint8).numpy()
    return Image.fromarray(pixels)


def _dummy_flare_forecast():
    """
    Return a placeholder flare forecast as a ``{label: confidence}`` dict.

    The probability comes from ``np.random.rand()``, not from the model. The
    reported confidence is that of whichever label the draw favours.
    """
    flare_probability = float(np.random.rand())
    if flare_probability > 0.5:
        return {"M-class or X-class Flare": flare_probability}
    return {"No significant flare": 1 - flare_probability}


def predict_solar_activity(time_steps, forecast_horizon):
    """
    Generate the demo's solar-activity outputs from the Surya model.

    No real observations are fetched. The model is run on random noise with
    layout ``[batch, channels, time_steps, height, width]``. A real deployment
    would replace that with aligned, normalized SDO data, as described in the
    Surya paper.

    Args:
        time_steps: Number of input time steps to simulate. It is coerced to
            an ``int`` of at least 1 because it is used as a tensor dimension,
            and the UI slider may deliver it as a float.
        forecast_horizon: Forecast horizon in hours from the UI. Currently
            unused; it is not passed to the model.

    Returns:
        ``(image, flare)``:
        - ``image`` is a PIL image of output channel 0 at the last time step,
          min-max scaled to 8 bits.
        - ``flare`` is a ``{label: confidence}`` dict for ``gr.Label``. The
          flare values are random placeholders, not model outputs.
    """
    # Tensor dimensions must be ints; a slider value such as 5.0 would make
    # torch.randn raise a TypeError.
    steps = max(1, int(round(time_steps)))

    # Placeholder observations: [batch, channels, time_steps, height, width].
    # At the default float32 dtype this is ~0.87 GB per time step
    # (13 * 4096 * 4096 * 4 bytes).
    dummy_input = torch.randn(1, 13, steps, 4096, 4096)

    with torch.no_grad():
        # forecast_horizon is not used by this call (see docstring).
        prediction = model(dummy_input)

    # Assumes the output keeps the input layout [batch, channel, time, H, W]
    # -- TODO confirm against Surya's forward(). Show channel 0, last step.
    predicted_image = _tensor_to_image(prediction[0, 0, -1, :, :])

    return predicted_image, _dummy_flare_forecast()

# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The blank lines inside the <div> close the HTML block. Without them the
    # heading and italic subtitle are shown as raw "#"/"*" text instead of
    # being parsed as Markdown.
    gr.Markdown(
        """
        <div align="center">

        # ☀️ Surya: Foundation Model for Heliophysics ☀️
        *A Gradio Demo for NASA's Solar Foundation Model*

        </div>
        """
    )
    gr.Markdown(
        "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
        "This demo showcases its capability to forecast solar dynamics."
    )

    with gr.Row():
        # Left column: user-controlled inputs.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model."
            )
            # step=1 keeps the horizon in whole hours, consistent with the
            # integer-stepped time-steps slider above.
            forecast_horizon_slider = gr.Slider(
                minimum=1, maximum=24, value=1, step=1,
                label="Forecast Horizon (hours)",
                info="How far into the future to predict."
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")

        # Right column: model outputs.
        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            # NOTE(review): predict_solar_activity displays output channel 0;
            # confirm that channel really is AIA 171 Å before trusting this label.
            output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
            gr.Markdown("### 💥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")

    # The slider values are passed positionally as (time_steps, forecast_horizon).
    # The returned (image, label dict) pair fills the two output components in order.
    predict_button.click(
        fn=predict_solar_activity,
        inputs=[time_steps_slider, forecast_horizon_slider],
        outputs=[output_image, output_flare]
    )

    gr.Markdown("---")
    gr.Markdown(
        "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
        "The core of this application is the loaded Surya model from NASA and IBM."
    )
    gr.Markdown(
        "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )

def _main():
    """Start the Gradio server for the ``demo`` app defined in this module."""
    demo.launch()


if __name__ == "__main__":
    _main()