Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
@@ -11,306 +13,271 @@ import os
|
|
11 |
import glob
|
12 |
import warnings
|
13 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
#
|
16 |
warnings.filterwarnings("ignore")
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
-
# ---
|
21 |
-
# NOTE: To make this script self-contained, the required classes and functions
|
22 |
-
# from the 'surya' library are included directly here.
|
23 |
-
# In a full installation, these would be imported.
|
24 |
-
|
25 |
-
from surya_dependencies import (
|
26 |
-
HelioSpectFormer,
|
27 |
-
HelioNetCDFDataset,
|
28 |
-
build_scalers,
|
29 |
-
custom_collate_fn,
|
30 |
-
inverse_transform_single_channel,
|
31 |
-
SDO_CHANNELS,
|
32 |
-
AIA_CHANNELS,
|
33 |
-
HMI_CHANNELS
|
34 |
-
)
|
35 |
-
|
36 |
-
# --- Global Cache for Model and Data ---
|
37 |
-
# We use a simple dictionary to act as a cache to avoid reloading.
|
38 |
APP_CACHE = {
|
39 |
"model": None,
|
40 |
"config": None,
|
41 |
"scalers": None,
|
42 |
-
"
|
43 |
-
"dataloader": None,
|
44 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
45 |
}
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
"""
|
51 |
-
|
52 |
-
|
53 |
-
This function is cached by Gradio to run only once.
|
54 |
"""
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
58 |
snapshot_download(
|
59 |
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
60 |
-
local_dir=
|
61 |
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
|
62 |
)
|
63 |
-
|
64 |
-
# Download validation data
|
65 |
-
data_dir = "data/Surya-1.0_validation_data"
|
66 |
snapshot_download(
|
67 |
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
|
68 |
repo_type="dataset",
|
69 |
-
local_dir=
|
70 |
allow_patterns="20140107_1[5-9]??.nc",
|
71 |
)
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
mlp_ratio=config["model"]["mlp_ratio"],
|
112 |
-
drop_rate=config["model"]["drop_rate"],
|
113 |
-
dtype=torch.bfloat16,
|
114 |
-
window_size=config["model"]["window_size"],
|
115 |
-
dp_rank=config["model"]["dp_rank"],
|
116 |
-
learned_flow=config["model"]["learned_flow"],
|
117 |
-
use_latitude_in_learned_flow=config["model"]["learned_flow"],
|
118 |
-
init_weights=False,
|
119 |
-
checkpoint_layers=[i for i in range(config["model"]["depth"])],
|
120 |
-
rpe=config["model"]["rpe"],
|
121 |
-
ensemble=config["model"]["ensemble"],
|
122 |
-
finetune=config["model"]["finetune"],
|
123 |
-
)
|
124 |
-
|
125 |
-
# Load pre-trained weights
|
126 |
-
path_weights = os.path.join(model_dir, "surya.366m.v1.pt")
|
127 |
-
weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
|
128 |
-
model.load_state_dict(weights, strict=True)
|
129 |
-
model.to(APP_CACHE["device"])
|
130 |
-
model.eval()
|
131 |
-
|
132 |
-
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
133 |
-
logger.info(f"Surya FM: {n_params:.2f}M parameters loaded to {APP_CACHE['device']}.")
|
134 |
-
APP_CACHE["model"] = model
|
135 |
-
|
136 |
-
def get_dataloader(index_path):
|
137 |
-
"""Initializes and returns a DataLoader for the validation data."""
|
138 |
-
if APP_CACHE["dataloader"] is None:
|
139 |
-
logger.info("Initializing dataset and dataloader...")
|
140 |
-
config = APP_CACHE["config"]
|
141 |
-
dataset = HelioNetCDFDataset(
|
142 |
-
index_path=index_path,
|
143 |
-
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
|
144 |
-
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
|
145 |
-
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
|
146 |
-
rollout_steps=1,
|
147 |
-
channels=config["data"]["sdo_channels"],
|
148 |
-
scalers=APP_CACHE["scalers"],
|
149 |
-
phase="valid", # Important: ensure no random augmentations
|
150 |
-
)
|
151 |
-
dataloader = DataLoader(
|
152 |
-
dataset, shuffle=False, batch_size=1, num_workers=2,
|
153 |
-
pin_memory=True, drop_last=False, collate_fn=custom_collate_fn,
|
154 |
-
)
|
155 |
-
APP_CACHE["dataloader"] = dataloader
|
156 |
-
APP_CACHE["dataset"] = dataset # Also cache dataset for transformation info
|
157 |
-
return APP_CACHE["dataloader"]
|
158 |
-
|
159 |
-
|
160 |
-
# --- 3. Core Inference and Visualization Logic ---
|
161 |
-
def run_model_inference():
|
162 |
"""
|
163 |
-
|
164 |
-
Returns the raw input, prediction, and ground truth tensors.
|
165 |
"""
|
|
|
|
|
|
|
166 |
model = APP_CACHE["model"]
|
167 |
-
|
168 |
device = APP_CACHE["device"]
|
169 |
|
170 |
-
#
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
with torch.no_grad():
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
-
def
|
190 |
"""
|
191 |
-
|
192 |
-
and converts them to displayable PIL Images.
|
193 |
"""
|
194 |
-
if
|
195 |
-
return None, None, None, "Please run the forecast
|
196 |
|
197 |
-
|
198 |
c_idx = SDO_CHANNELS.index(channel_name)
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
# --- Denormalize data for visualization ---
|
203 |
-
# Final input image given to the model (last in sequence)
|
204 |
input_slice = inverse_transform_single_channel(
|
205 |
-
|
206 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
207 |
)
|
208 |
-
# Model's prediction
|
209 |
pred_slice = inverse_transform_single_channel(
|
210 |
-
|
211 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
212 |
)
|
213 |
-
# Ground truth image
|
214 |
target_slice = inverse_transform_single_channel(
|
215 |
-
|
216 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
217 |
)
|
218 |
-
|
219 |
-
#
|
220 |
-
|
221 |
-
vmax = np.quantile(np.concatenate([input_slice, pred_slice, target_slice]), 0.995)
|
222 |
-
|
223 |
-
# Determine colormap from channel name
|
224 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
225 |
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
data_norm = (data_clipped - vmin) / (vmax - vmin)
|
230 |
-
return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8))
|
231 |
|
232 |
-
return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice),
|
233 |
|
234 |
# --- 4. Gradio Controller Functions ---
|
235 |
-
def forecast_controller(channel_name, progress=gr.Progress()):
|
236 |
-
"""
|
237 |
-
Main function triggered by the 'Generate' button. Orchestrates the entire pipeline.
|
238 |
-
"""
|
239 |
-
progress(0, desc="Downloading model and data (first launch only)...")
|
240 |
-
index_path, model_dir = setup_environment_and_download_data()
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
|
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
img_in, img_pred, img_target, status = create_visualizations(channel_name, input_t, pred_t, target_t)
|
251 |
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
|
255 |
# --- 5. Gradio UI Layout ---
|
256 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
257 |
-
#
|
258 |
-
|
259 |
-
|
260 |
-
state_target = gr.State()
|
261 |
-
|
262 |
gr.Markdown(
|
263 |
"""
|
264 |
-
<div align=
|
265 |
# ☀️ Surya: Live Model Demo ☀️
|
266 |
### An Interactive Interface for NASA's Heliophysics Foundation Model
|
267 |
This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
|
268 |
-
|
|
|
269 |
</div>
|
270 |
"""
|
271 |
)
|
272 |
-
|
273 |
-
with gr.Row():
|
274 |
-
channel_selector = gr.Dropdown(
|
275 |
-
choices=SDO_CHANNELS,
|
276 |
-
value="aia171",
|
277 |
-
label="🛰️ Select SDO Instrument Channel",
|
278 |
-
info="Choose which solar observation channel to visualize."
|
279 |
-
)
|
280 |
-
run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary", scale=2)
|
281 |
-
|
282 |
-
status_box = gr.Textbox(label="Status", interactive=False, value="Ready. Press 'Generate Forecast' to start.")
|
283 |
-
|
284 |
with gr.Row():
|
285 |
-
with gr.Column():
|
286 |
-
gr.
|
287 |
-
gr.
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
gr.
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
|
|
298 |
# --- Event Handlers ---
|
299 |
run_button.click(
|
300 |
fn=forecast_controller,
|
301 |
-
|
302 |
-
outputs=[input_display, prediction_display, label_display, status_box, state_input, state_prediction, state_target]
|
303 |
)
|
304 |
|
|
|
305 |
channel_selector.change(
|
306 |
-
fn=
|
307 |
-
inputs=[
|
308 |
-
outputs=[input_display, prediction_display,
|
|
|
|
|
|
|
|
|
|
|
309 |
)
|
310 |
|
311 |
if __name__ == "__main__":
|
312 |
-
# The 'surya_dependencies.py' file must be in the same directory as this script.
|
313 |
-
# Create the placeholder file if it doesn't exist.
|
314 |
-
if not os.path.exists("surya_dependencies.py"):
|
315 |
-
raise FileNotFoundError("The required 'surya_dependencies.py' file is missing. Please download it from the provided source.")
|
316 |
demo.launch(debug=True)
|
|
|
1 |
+
# Save this file as app.py in the root of the cloned Surya repository
|
2 |
+
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
13 |
import glob
|
14 |
import warnings
|
15 |
import logging
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
# --- Use the official Surya modules now that we are in the repo ---
|
19 |
+
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
|
20 |
+
from surya.models.helio_spectformer import HelioSpectFormer
|
21 |
+
from surya.utils.data import build_scalers, custom_collate_fn
|
22 |
|
23 |
+
# Suppress verbose logging and warnings for a cleaner UI
|
24 |
warnings.filterwarnings("ignore")
|
25 |
logging.basicConfig(level=logging.INFO)
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
28 |
+
# --- Global cache to store expensive-to-load objects ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
APP_CACHE = {
|
30 |
"model": None,
|
31 |
"config": None,
|
32 |
"scalers": None,
|
33 |
+
"full_results": None, # Will store all prediction results
|
|
|
34 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
35 |
}
|
36 |
|
37 |
+
# SDO channels from the test script for the dropdown menu
|
38 |
+
SDO_CHANNELS = [
|
39 |
+
"aia94", "aia131", "aia171", "aia193", "aia211", "aia304", "aia335",
|
40 |
+
"aia1600", "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
|
41 |
+
]
|
42 |
+
|
43 |
+
# --- 1. Setup, Download, and Model Loading (adapting fixtures from test_surya.py) ---
|
44 |
+
|
45 |
+
def setup_and_load_model(progress=gr.Progress()):
|
46 |
"""
|
47 |
+
Handles all initial setup: downloading data, loading configs, and initializing the model.
|
48 |
+
This function will populate the APP_CACHE.
|
|
|
49 |
"""
|
50 |
+
if APP_CACHE["model"] is not None:
|
51 |
+
logger.info("Model and data already loaded. Skipping setup.")
|
52 |
+
return
|
53 |
+
|
54 |
+
# --- Part A: Download data (from download_data fixture) ---
|
55 |
+
progress(0.1, desc="Downloading model weights and config...")
|
56 |
snapshot_download(
|
57 |
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
58 |
+
local_dir="data/Surya-1.0",
|
59 |
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
|
60 |
)
|
61 |
+
progress(0.3, desc="Downloading validation data for 2014-01-07...")
|
|
|
|
|
62 |
snapshot_download(
|
63 |
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
|
64 |
repo_type="dataset",
|
65 |
+
local_dir="data/Surya-1.0_validation_data",
|
66 |
allow_patterns="20140107_1[5-9]??.nc",
|
67 |
)
|
68 |
|
69 |
+
# --- Part B: Load Config and Scalers (from config & scalers fixtures) ---
|
70 |
+
progress(0.5, desc="Loading configuration and data scalers...")
|
71 |
+
with open("data/Surya-1.0/config.yaml") as fp:
|
72 |
+
config = yaml.safe_load(fp)
|
73 |
+
APP_CACHE["config"] = config
|
74 |
|
75 |
+
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
76 |
+
APP_CACHE["scalers"] = build_scalers(info=scalers_info)
|
77 |
+
|
78 |
+
# --- Part C: Initialize and load model (from model fixture and test function) ---
|
79 |
+
progress(0.7, desc="Initializing model architecture...")
|
80 |
+
model_config = config["model"]
|
81 |
+
model = HelioSpectFormer(
|
82 |
+
img_size=model_config["img_size"], patch_size=model_config["patch_size"],
|
83 |
+
in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
|
84 |
+
time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
|
85 |
+
depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
|
86 |
+
num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
|
87 |
+
drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
|
88 |
+
window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
|
89 |
+
learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
|
90 |
+
init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
|
91 |
+
rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
|
92 |
+
)
|
93 |
+
|
94 |
+
progress(0.8, desc=f"Loading model weights to {APP_CACHE['device']}...")
|
95 |
+
path_weights = "data/Surya-1.0/surya.366m.v1.pt"
|
96 |
+
weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
|
97 |
+
model.load_state_dict(weights, strict=True)
|
98 |
+
model.to(APP_CACHE["device"])
|
99 |
+
model.eval()
|
100 |
+
|
101 |
+
n_params = sum(p.numel() for p in model.parameters()) / 1e6
|
102 |
+
logger.info(f"Surya FM: {n_params:.2f}M parameters loaded.")
|
103 |
+
APP_CACHE["model"] = model
|
104 |
+
|
105 |
+
# --- 2. Inference Logic (adapting the test loop) ---
|
106 |
+
|
107 |
+
def run_full_forecast():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
"""
|
109 |
+
Runs inference on the entire validation dataset and stores results.
|
|
|
110 |
"""
|
111 |
+
if APP_CACHE["full_results"] is not None:
|
112 |
+
return APP_CACHE["full_results"]
|
113 |
+
|
114 |
model = APP_CACHE["model"]
|
115 |
+
config = APP_CACHE["config"]
|
116 |
device = APP_CACHE["device"]
|
117 |
|
118 |
+
# Create the index file needed by the dataset loader
|
119 |
+
os.makedirs("tests", exist_ok=True)
|
120 |
+
with open("tests/test_surya_index.csv", "w") as f:
|
121 |
+
f.write("path\n")
|
122 |
+
search_path = os.path.join("data/Surya-1.0_validation_data", "**", "*.nc")
|
123 |
+
for nc_file in sorted(glob.glob(search_path, recursive=True)):
|
124 |
+
f.write(f"{nc_file}\n")
|
125 |
|
126 |
+
# Setup dataset and dataloader (from dataset & dataloader fixtures)
|
127 |
+
dataset = HelioNetCDFDataset(
|
128 |
+
index_path="tests/test_surya_index.csv",
|
129 |
+
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
|
130 |
+
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
|
131 |
+
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
|
132 |
+
rollout_steps=1, channels=config["data"]["sdo_channels"],
|
133 |
+
scalers=APP_CACHE["scalers"], phase="valid",
|
134 |
+
)
|
135 |
+
dataloader = DataLoader(
|
136 |
+
dataset, shuffle=False, batch_size=1, num_workers=2,
|
137 |
+
pin_memory=True, collate_fn=custom_collate_fn
|
138 |
+
)
|
139 |
+
|
140 |
+
all_results = []
|
141 |
with torch.no_grad():
|
142 |
+
for batch_data, batch_metadata in dataloader:
|
143 |
+
input_batch = {k: v.to(device) for k, v in batch_data.items() if k in ["ts", "time_delta_input"]}
|
144 |
+
|
145 |
+
with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
|
146 |
+
prediction = model(input_batch)
|
147 |
+
|
148 |
+
# Store the relevant tensors on CPU for later visualization
|
149 |
+
result = {
|
150 |
+
"input": input_batch["ts"].cpu(),
|
151 |
+
"prediction": prediction.cpu(),
|
152 |
+
"target": batch_data["forecast"].cpu(),
|
153 |
+
"input_timestamp": np.datetime_as_string(batch_metadata["timestamps_input"][0][-1], unit='s'),
|
154 |
+
"target_timestamp": np.datetime_as_string(batch_metadata["timestamps_targets"][0][0], unit='s'),
|
155 |
+
}
|
156 |
+
all_results.append(result)
|
157 |
+
|
158 |
+
APP_CACHE["full_results"] = all_results
|
159 |
+
# Cache scalers needed for visualization
|
160 |
+
APP_CACHE["scalers_vis"] = dataset.transformation_inputs()
|
161 |
+
return all_results
|
162 |
+
|
163 |
+
# --- 3. Visualization Logic ---
|
164 |
|
165 |
+
def generate_visualization(results, timestep_index, channel_name):
|
166 |
"""
|
167 |
+
Generates PIL images for a specific timestep and channel from the results.
|
|
|
168 |
"""
|
169 |
+
if not results:
|
170 |
+
return None, None, None, "No results available. Please run the forecast.", ""
|
171 |
|
172 |
+
timestep_data = results[timestep_index]
|
173 |
c_idx = SDO_CHANNELS.index(channel_name)
|
174 |
+
means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers_vis"]
|
175 |
+
|
176 |
+
# Denormalize data for visualization
|
|
|
|
|
177 |
input_slice = inverse_transform_single_channel(
|
178 |
+
timestep_data["input"][0, c_idx, -1].numpy(),
|
179 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
180 |
)
|
|
|
181 |
pred_slice = inverse_transform_single_channel(
|
182 |
+
timestep_data["prediction"][0, c_idx].numpy(),
|
183 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
184 |
)
|
|
|
185 |
target_slice = inverse_transform_single_channel(
|
186 |
+
timestep_data["target"][0, c_idx, 0].numpy(),
|
187 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
188 |
)
|
189 |
+
|
190 |
+
# Convert to PIL Images using appropriate colormaps
|
191 |
+
vmax = np.quantile(target_slice, 0.995)
|
|
|
|
|
|
|
192 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
193 |
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
194 |
+
|
195 |
+
def to_pil(data):
|
196 |
+
data_clipped = np.clip(data, 0, vmax)
|
197 |
+
data_norm = data_clipped / vmax
|
198 |
+
return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)).transpose(Image.Transpose.TRANSPOSE)
|
199 |
|
200 |
+
status_text = (f"Displaying Timestep {timestep_index+1}/{len(results)}\n"
|
201 |
+
f"Input: {timestep_data['input_timestamp']} | Forecast/Target: {timestep_data['target_timestamp']}")
|
|
|
|
|
202 |
|
203 |
+
return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), status_text
|
204 |
|
205 |
# --- 4. Gradio Controller Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
+
def forecast_controller(progress=gr.Progress(track_tqdm=True)):
|
208 |
+
"""Main function for the 'Generate Forecast' button."""
|
209 |
+
progress(0, desc="Starting setup...")
|
210 |
+
setup_and_load_model(progress)
|
211 |
|
212 |
+
logger.info("Running forecast on all validation timesteps...")
|
213 |
+
progress(0.9, desc="Running inference on validation data...")
|
214 |
+
results = run_full_forecast()
|
215 |
+
logger.info(f"Forecast complete. {len(results)} timesteps processed.")
|
|
|
216 |
|
217 |
+
# Generate the first visualization
|
218 |
+
img_in, img_pred, img_target, status = generate_visualization(results, 0, SDO_CHANNELS[2]) # Default to aia171
|
219 |
+
|
220 |
+
# Update the slider to be interactive and have the correct number of steps
|
221 |
+
slider_update = gr.Slider(minimum=1, maximum=len(results), step=1, value=1, interactive=True,
|
222 |
+
label="Forecast Timestep")
|
223 |
+
|
224 |
+
return results, img_in, img_pred, img_target, status, slider_update
|
225 |
+
|
226 |
+
def update_visualization_controller(results, timestep, channel):
|
227 |
+
"""Called when a slider or dropdown is changed."""
|
228 |
+
return generate_visualization(results, timestep - 1, channel)
|
229 |
|
230 |
|
231 |
# --- 5. Gradio UI Layout ---
|
232 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
233 |
+
# State object to hold the results of the full inference run
|
234 |
+
state_results = gr.State()
|
235 |
+
|
|
|
|
|
236 |
gr.Markdown(
|
237 |
"""
|
238 |
+
<div align='center'>
|
239 |
# ☀️ Surya: Live Model Demo ☀️
|
240 |
### An Interactive Interface for NASA's Heliophysics Foundation Model
|
241 |
This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
|
242 |
+
<br>
|
243 |
+
**Instructions:** 1. Click 'Generate Forecast'. 2. Use the controls to explore the results.
|
244 |
</div>
|
245 |
"""
|
246 |
)
|
247 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
with gr.Row():
|
249 |
+
with gr.Column(scale=1):
|
250 |
+
run_button = gr.Button("🔮 1. Generate Full Forecast", variant="primary")
|
251 |
+
status_box = gr.Textbox(label="Status", interactive=False, value="Ready.", lines=2)
|
252 |
+
channel_selector = gr.Dropdown(
|
253 |
+
choices=SDO_CHANNELS, value="aia171", label="🛰️ 2. Select SDO Channel"
|
254 |
+
)
|
255 |
+
timestep_slider = gr.Slider(
|
256 |
+
minimum=1, maximum=8, step=1, value=1, interactive=False, label="Forecast Timestep"
|
257 |
+
)
|
258 |
+
with gr.Column(scale=3):
|
259 |
+
with gr.Row():
|
260 |
+
input_display = gr.Image(label="Last Input", height=512, width=512, interactive=False)
|
261 |
+
prediction_display = gr.Image(label="Model Forecast", height=512, width=512, interactive=False)
|
262 |
+
target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
|
263 |
+
|
264 |
# --- Event Handlers ---
|
265 |
run_button.click(
|
266 |
fn=forecast_controller,
|
267 |
+
outputs=[state_results, input_display, prediction_display, target_display, status_box, timestep_slider]
|
|
|
268 |
)
|
269 |
|
270 |
+
# When the user changes the channel or timestep, call the visualization update function
|
271 |
channel_selector.change(
|
272 |
+
fn=update_visualization_controller,
|
273 |
+
inputs=[state_results, timestep_slider, channel_selector],
|
274 |
+
outputs=[input_display, prediction_display, target_display, status_box]
|
275 |
+
)
|
276 |
+
timestep_slider.change(
|
277 |
+
fn=update_visualization_controller,
|
278 |
+
inputs=[state_results, timestep_slider, channel_selector],
|
279 |
+
outputs=[input_display, prediction_display, target_display, status_box]
|
280 |
)
|
281 |
|
282 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
283 |
demo.launch(debug=True)
|