Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,26 +1,49 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
-
from surya.model import Surya
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
|
|
7 |
import warnings
|
8 |
|
9 |
# Suppress warnings for a cleaner demo experience
|
10 |
warnings.filterwarnings("ignore")
|
11 |
|
12 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
@gr.cache
|
14 |
-
def
|
15 |
"""
|
16 |
-
Downloads the pre-trained Surya model
|
17 |
-
This function is cached so
|
18 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
checkpoint_path = hf_hub_download(
|
20 |
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
21 |
-
filename="surya.366m.v1.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
|
|
23 |
|
|
|
24 |
model = Surya(
|
25 |
img_size=4096,
|
26 |
patch_size=16,
|
@@ -30,106 +53,120 @@ def load_model():
|
|
30 |
attention_blocks=8,
|
31 |
)
|
32 |
|
|
|
|
|
33 |
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
34 |
model.eval()
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
model
|
38 |
|
39 |
-
# ---
|
40 |
-
def
|
41 |
"""
|
42 |
-
|
43 |
-
For this demo, we use a dummy input tensor to simulate the model's input.
|
44 |
-
In a real-world scenario, this function would fetch and preprocess
|
45 |
-
actual SDO data for the given time steps.
|
46 |
"""
|
47 |
-
#
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
# In a real application, you would replace the dummy input with actual,
|
52 |
-
# preprocessed data from the Solar Dynamics Observatory (SDO).
|
53 |
-
# Preprocessing would involve alignment and normalization as described
|
54 |
-
# in the Surya paper.
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
with torch.no_grad():
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
#
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
return predicted_image, {flare_class: confidence}
|
82 |
-
|
83 |
-
# --- Gradio Interface Definition ---
|
84 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
85 |
gr.Markdown(
|
86 |
"""
|
87 |
<div align="center">
|
88 |
-
# ☀️ Surya:
|
89 |
-
|
|
|
90 |
</div>
|
91 |
"""
|
92 |
)
|
93 |
-
gr.Markdown(
|
94 |
-
"Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
|
95 |
-
"This demo showcases its capability to forecast solar dynamics."
|
96 |
-
)
|
97 |
|
98 |
with gr.Row():
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
minimum=1, maximum=24, value=1,
|
108 |
-
label="Forecast Horizon (hours)",
|
109 |
-
info="How far into the future to predict."
|
110 |
-
)
|
111 |
-
predict_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
112 |
-
|
113 |
-
with gr.Column(scale=2):
|
114 |
-
gr.Markdown("### 🛰️ Predicted Solar Image")
|
115 |
-
output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
|
116 |
-
gr.Markdown("### 💥 Solar Flare Prediction")
|
117 |
-
output_flare = gr.Label(label="Flare Probability")
|
118 |
-
|
119 |
-
predict_button.click(
|
120 |
-
fn=predict_solar_activity,
|
121 |
-
inputs=[time_steps_slider, forecast_horizon_slider],
|
122 |
-
outputs=[output_image, output_flare]
|
123 |
-
)
|
124 |
|
125 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
gr.Markdown(
|
127 |
-
"
|
128 |
-
"The
|
|
|
|
|
129 |
)
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
132 |
)
|
133 |
|
134 |
if __name__ == "__main__":
|
135 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
+
from surya.model import Surya # This now works because of the file structure
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
+
import os
|
8 |
import warnings
|
9 |
|
10 |
# Suppress warnings for a cleaner demo experience
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
+
# --- 1. Define Constants and Data Channels ---
|
14 |
+
# Based on the Surya project's data preprocessing
|
15 |
+
AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
|
16 |
+
HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
|
17 |
+
ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]
|
18 |
+
|
19 |
+
# --- 2. Caching and Loading the Model and Data ---
|
20 |
@gr.cache
|
21 |
+
def load_model_and_data():
|
22 |
"""
|
23 |
+
Downloads the pre-trained Surya model, the test data, and initializes the model.
|
24 |
+
This function is cached so this happens only once.
|
25 |
"""
|
26 |
+
print("Downloading model and test data... This may take a moment.")
|
27 |
+
# Define local directories for caching
|
28 |
+
model_dir = "./surya_model"
|
29 |
+
data_dir = "./surya_data"
|
30 |
+
os.makedirs(model_dir, exist_ok=True)
|
31 |
+
os.makedirs(data_dir, exist_ok=True)
|
32 |
+
|
33 |
+
# Download the model weights and test data from Hugging Face
|
34 |
checkpoint_path = hf_hub_download(
|
35 |
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
36 |
+
filename="surya.366m.v1.pt",
|
37 |
+
local_dir=model_dir
|
38 |
+
)
|
39 |
+
test_data_path = hf_hub_download(
|
40 |
+
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
41 |
+
filename="test_data.pt",
|
42 |
+
local_dir=data_dir
|
43 |
)
|
44 |
+
print("Downloads complete.")
|
45 |
|
46 |
+
# Initialize the model architecture
|
47 |
model = Surya(
|
48 |
img_size=4096,
|
49 |
patch_size=16,
|
|
|
53 |
attention_blocks=8,
|
54 |
)
|
55 |
|
56 |
+
# Load the weights into the model
|
57 |
+
print("Loading model weights...")
|
58 |
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
59 |
model.eval()
|
60 |
+
print("Model loaded successfully.")
|
61 |
+
|
62 |
+
# Load the test data
|
63 |
+
test_data = torch.load(test_data_path)
|
64 |
+
test_input = test_data["input"] # Input tensor for the model
|
65 |
+
test_label = test_data["label"] # Ground truth for comparison
|
66 |
|
67 |
+
return model, test_input, test_label
|
68 |
|
69 |
+
# --- 3. Helper function for Image Conversion ---
|
70 |
+
def tensor_to_image(tensor_slice):
|
71 |
"""
|
72 |
+
Normalizes a 2D tensor slice and converts it to a PIL Image for display.
|
|
|
|
|
|
|
73 |
"""
|
74 |
+
# Detach tensor from graph, move to CPU, and convert to numpy
|
75 |
+
img_np = tensor_slice.detach().cpu().numpy()
|
76 |
+
|
77 |
+
# Normalize the tensor to a 0-255 range for image display
|
78 |
+
min_val, max_val = np.min(img_np), np.max(img_np)
|
79 |
+
if max_val > min_val:
|
80 |
+
img_np = (img_np - min_val) / (max_val - min_val)
|
81 |
+
|
82 |
+
img_array = (img_np * 255).astype(np.uint8)
|
83 |
+
return Image.fromarray(img_array)
|
84 |
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
# --- 4. Main Prediction and Visualization Function ---
|
87 |
+
def run_forecast(channel_name, progress=gr.Progress()):
|
88 |
+
"""
|
89 |
+
This function is triggered by the button click in the Gradio interface.
|
90 |
+
It runs the model prediction and generates the images for display.
|
91 |
+
"""
|
92 |
+
progress(0, desc="Loading model and data (first run may be slow)...")
|
93 |
+
# Load the model and data (will be fast after the first run due to caching)
|
94 |
+
model, test_input, test_label = load_model_and_data()
|
95 |
+
|
96 |
+
progress(0.5, desc="Running inference on the model...")
|
97 |
+
# Perform the forecast
|
98 |
with torch.no_grad():
|
99 |
+
prediction = model(test_input)
|
100 |
+
|
101 |
+
progress(0.8, desc="Generating visualizations...")
|
102 |
+
# Get the index of the selected channel
|
103 |
+
channel_index = ALL_CHANNELS.index(channel_name)
|
104 |
+
|
105 |
+
# Extract the last time step from the input sequence for display
|
106 |
+
# Shape: [batch, channels, time, height, width] -> select channel, last time step
|
107 |
+
input_slice = test_input[0, channel_index, -1, :, :]
|
108 |
+
input_image = tensor_to_image(input_slice)
|
109 |
+
|
110 |
+
# Extract the corresponding slice from the model's prediction
|
111 |
+
# Shape: [batch, channels, time, height, width] -> select channel, first predicted step
|
112 |
+
predicted_slice = prediction[0, channel_index, 0, :, :]
|
113 |
+
predicted_image = tensor_to_image(predicted_slice)
|
114 |
+
|
115 |
+
# Extract the corresponding slice from the ground truth label
|
116 |
+
label_slice = test_label[0, channel_index, 0, :, :]
|
117 |
+
label_image = tensor_to_image(label_slice)
|
118 |
+
|
119 |
+
print(f"Forecast generated for channel: {channel_name}")
|
120 |
+
return input_image, predicted_image, label_image
|
121 |
+
|
122 |
+
# --- 5. Building the Gradio Interface ---
|
|
|
|
|
|
|
123 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
124 |
gr.Markdown(
|
125 |
"""
|
126 |
<div align="center">
|
127 |
+
# ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️
|
128 |
+
This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,
|
129 |
+
allowing a direct comparison between the model's prediction and the real ground truth.
|
130 |
</div>
|
131 |
"""
|
132 |
)
|
|
|
|
|
|
|
|
|
133 |
|
134 |
with gr.Row():
|
135 |
+
channel_selector = gr.Dropdown(
|
136 |
+
choices=ALL_CHANNELS,
|
137 |
+
value=ALL_CHANNELS[2], # Default to "AIA 171 Å"
|
138 |
+
label="🛰️ Select SDO Instrument Channel",
|
139 |
+
info="Choose which solar observation channel to visualize."
|
140 |
+
)
|
141 |
+
|
142 |
+
run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column():
|
146 |
+
gr.Markdown("### ⬅️ Final Input Image")
|
147 |
+
gr.Markdown("The last image shown to the model before it makes a prediction.")
|
148 |
+
input_display = gr.Image(label="Input Observation", height=400, width=400)
|
149 |
+
with gr.Column():
|
150 |
+
gr.Markdown("### 🔮 Model's Forecast")
|
151 |
+
gr.Markdown("What the Surya model predicted the Sun would look like.")
|
152 |
+
prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
|
153 |
+
with gr.Column():
|
154 |
+
gr.Markdown("### ✅ Ground Truth")
|
155 |
+
gr.Markdown("What the Sun *actually* looked like at the forecast time.")
|
156 |
+
label_display = gr.Image(label="Actual Observation", height=400, width=400)
|
157 |
+
|
158 |
gr.Markdown(
|
159 |
+
"--- \n"
|
160 |
+
"**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
|
161 |
+
"The images are downscaled for display in this demo. "
|
162 |
+
"For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
|
163 |
)
|
164 |
+
|
165 |
+
run_button.click(
|
166 |
+
fn=run_forecast,
|
167 |
+
inputs=[channel_selector],
|
168 |
+
outputs=[input_display, prediction_display, label_display]
|
169 |
)
|
170 |
|
171 |
if __name__ == "__main__":
|
172 |
+
demo.launch(debug=True)
|