File size: 6,705 Bytes
f966038
5f00446
 
f43722b
f966038
 
f43722b
f966038
 
 
 
 
f43722b
 
 
 
 
 
 
f966038
f43722b
f966038
f43722b
 
f966038
f43722b
 
 
 
 
 
 
 
f966038
 
f43722b
 
 
 
 
 
 
f966038
f43722b
f966038
f43722b
f966038
 
 
 
 
 
 
 
 
f43722b
 
f966038
 
f43722b
 
 
 
 
 
f966038
f43722b
f966038
f43722b
 
f966038
f43722b
f966038
f43722b
 
 
 
 
 
 
 
 
 
f966038
 
f43722b
 
 
 
 
 
 
 
 
 
 
 
f966038
f43722b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f966038
 
 
 
f43722b
 
 
f966038
 
 
 
 
f43722b
 
 
 
 
 
 
 
f966038
f43722b
 
 
 
 
 
 
 
 
 
 
 
 
 
f966038
f43722b
 
 
 
f966038
f43722b
 
 
 
 
f966038
 
 
f43722b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import functools
import os
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from surya.model import Surya # This now works because of the file structure

# Silence every warning category so the demo's console output stays readable.
warnings.filterwarnings("ignore")

# --- 1. Define Constants and Data Channels ---
# Channel identifiers for the AIA and HMI instruments, in the order the code
# below assumes they appear along the tensors' channel axis.
# NOTE(review): 8 AIA + 5 HMI = 13 labels, which must line up with
# `in_chans=13` passed to Surya(...) and with the real channel order of the
# test data — confirm against the Surya data configuration.
AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
ALL_CHANNELS = [
    *(f"AIA {wavelength} Å" for wavelength in AIA_CHANNELS),
    *(f"HMI {component}" for component in HMI_CHANNELS),
]

# --- 2. Caching and Loading the Model and Data ---
@functools.lru_cache(maxsize=1)
def load_model_and_data():
    """
    Download the Surya checkpoint and test data, then build and load the model.

    Memoized with ``functools.lru_cache(maxsize=1)``: the downloads, model
    construction and weight loading happen only on the first call within a
    process, and every later call returns the same objects.

    Returns:
        tuple: ``(model, test_input, test_label)`` where ``model`` is the Surya
        module (weights loaded on the CPU, switched to eval mode), and
        ``test_input`` / ``test_label`` are the ``"input"`` and ``"label"``
        entries of the downloaded ``test_data.pt``.
    """
    print("Downloading model and test data... This may take a moment.")
    # Define local directories for caching
    model_dir = "./surya_model"
    data_dir = "./surya_data"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Download the model weights and test data from Hugging Face
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
        local_dir=model_dir
    )
    test_data_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="test_data.pt",
        local_dir=data_dir
    )
    print("Downloads complete.")

    # Initialize the model architecture; in_chans must match len(ALL_CHANNELS).
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    # Load the weights into the model.
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    print("Loading model weights...")
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully.")

    # Load the test data onto the CPU as well, so it sits on the same device as
    # the model even if the file was saved from a GPU session.
    test_data = torch.load(test_data_path, map_location="cpu")
    test_input = test_data["input"]   # Input tensor for the model
    test_label = test_data["label"]   # Ground truth for comparison

    return model, test_input, test_label

# --- 3. Helper function for Image Conversion ---
def tensor_to_image(tensor_slice):
    """
    Min-max normalize a 2D tensor slice and convert it to an 8-bit PIL Image.

    Args:
        tensor_slice: 2D tensor (height, width); it may live on any device and
            may be attached to an autograd graph.

    Returns:
        PIL.Image.Image: grayscale image in which the smallest finite value maps
        to 0 and the largest finite value maps to 255. A constant slice, or one
        with no finite values, yields an all-black image; NaN and -inf pixels
        render black and +inf pixels render white.
    """
    # Detach from the graph and move to the CPU before handing off to numpy.
    tensor_slice = tensor_slice.detach().cpu()
    # numpy has no bfloat16 (``.numpy()`` would raise), and normalizing in
    # 16-bit floats loses precision, so half-precision inputs are upcast.
    if tensor_slice.dtype in (torch.float16, torch.bfloat16):
        tensor_slice = tensor_slice.float()
    img_np = tensor_slice.numpy()

    # Compute the range over finite pixels only: a single NaN/inf would
    # otherwise poison min/max and leave the whole image un-normalized.
    finite = np.isfinite(img_np)
    if finite.any():
        min_val, max_val = img_np[finite].min(), img_np[finite].max()
    else:
        min_val = max_val = 0.0

    if max_val > min_val:
        img_np = (img_np - min_val) / (max_val - min_val)
    else:
        # No contrast to show. Leaving raw values here would make anything
        # outside [0, 1] wrap around when cast to uint8 below.
        img_np = np.zeros(img_np.shape, dtype=np.float64)

    # Map leftover non-finite pixels into [0, 1] and clip as a final guard
    # before scaling to the 0-255 display range.
    img_np = np.nan_to_num(img_np, nan=0.0, posinf=1.0, neginf=0.0)
    img_array = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(img_array)


# --- 4. Main Prediction and Visualization Function ---
@functools.lru_cache(maxsize=1)
def _cached_prediction():
    """Run the model once on the bundled test input and memoize the output.

    The forward pass only sees ``test_input`` (never the selected channel), so
    switching channels can reuse the previous prediction instead of repeating
    the expensive inference.
    """
    model, test_input, _ = load_model_and_data()
    with torch.no_grad():
        return model(test_input)


def run_forecast(channel_name, progress=gr.Progress()):
    """
    Produce the input, forecast and ground-truth images for one channel.

    Triggered by the "Generate Forecast" button in the Gradio interface.

    Args:
        channel_name: One of ``ALL_CHANNELS``; selects which channel is shown.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        tuple: three PIL images — the last input frame, the model's first
        predicted frame, and the matching ground-truth frame.

    Raises:
        gr.Error: if ``channel_name`` is not one of ``ALL_CHANNELS``.
    """
    # Validate before any heavy work so a bad selection fails fast with a
    # readable message in the UI instead of a ValueError after inference.
    if channel_name not in ALL_CHANNELS:
        raise gr.Error(f"Unknown channel: {channel_name!r}. Please select one from the list.")
    channel_index = ALL_CHANNELS.index(channel_name)

    progress(0, desc="Loading model and data (first run may be slow)...")
    # Fast after the first run because load_model_and_data is cached.
    _, test_input, test_label = load_model_and_data()

    progress(0.5, desc="Running inference on the model...")
    prediction = _cached_prediction()

    progress(0.8, desc="Generating visualizations...")
    # NOTE(review): the indexing below assumes rank-5 tensors laid out as
    # [batch, channels, time, height, width] — confirm against the Surya data.
    # Last time step of the input sequence (what the model saw last).
    input_image = tensor_to_image(test_input[0, channel_index, -1, :, :])
    # First predicted step from the model's output.
    predicted_image = tensor_to_image(prediction[0, channel_index, 0, :, :])
    # Matching step from the ground-truth label.
    label_image = tensor_to_image(test_label[0, channel_index, 0, :, :])

    print(f"Forecast generated for channel: {channel_name}")
    return input_image, predicted_image, label_image

# --- 5. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header. Markdown inside a raw HTML block is only parsed as Markdown when
    # it is separated from the HTML tags by blank lines (CommonMark HTML-block
    # rule); without them the "#" heading would render as literal text.
    gr.Markdown(
        '<div align="center">\n\n'
        "# ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️\n\n"
        "This demo runs the actual Surya model to forecast solar activity. "
        "It uses the official test data for **2014-01-07**,\n"
        "allowing a direct comparison between the model's prediction and the real ground truth.\n\n"
        "</div>"
    )

    with gr.Row():
        channel_selector = gr.Dropdown(
            choices=ALL_CHANNELS,
            value=ALL_CHANNELS[2], # Default to "AIA 171 Å"
            label="🛰️ Select SDO Instrument Channel",
            info="Choose which solar observation channel to visualize."
        )

    run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")

    # Three side-by-side panels, in the same order run_forecast returns images.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ⬅️ Final Input Image")
            gr.Markdown("The last image shown to the model before it makes a prediction.")
            input_display = gr.Image(label="Input Observation", height=400, width=400)
        with gr.Column():
            gr.Markdown("### 🔮 Model's Forecast")
            gr.Markdown("What the Surya model predicted the Sun would look like.")
            prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
        with gr.Column():
            gr.Markdown("### ✅ Ground Truth")
            gr.Markdown("What the Sun *actually* looked like at the forecast time.")
            label_display = gr.Image(label="Actual Observation", height=400, width=400)

    # The images are not resized server-side; height/width above only
    # constrain how large they are displayed.
    gr.Markdown(
        "--- \n"
        "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
        "The images are shown scaled down to fit the page. "
        "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
    )

    # Wire the button: the dropdown value goes in, three images come out.
    run_button.click(
        fn=run_forecast,
        inputs=[channel_selector],
        outputs=[input_display, prediction_display, label_display]
    )

if __name__ == "__main__":
    # debug=True keeps the process attached and prints errors to the console.
    demo.launch(debug=True)