broadfield-dev commited on
Commit
2870322
·
verified ·
1 Parent(s): 5d92716

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -225
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn.functional as F
@@ -11,306 +13,271 @@ import os
11
  import glob
12
  import warnings
13
  import logging
 
 
 
 
 
 
14
 
15
- # --- Suppress verbose logging and warnings for a cleaner UI ---
16
  warnings.filterwarnings("ignore")
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
- # --- Dependencies from the Surya Repository ---
21
- # NOTE: To make this script self-contained, the required classes and functions
22
- # from the 'surya' library are included directly here.
23
- # In a full installation, these would be imported.
24
-
25
- from surya_dependencies import (
26
- HelioSpectFormer,
27
- HelioNetCDFDataset,
28
- build_scalers,
29
- custom_collate_fn,
30
- inverse_transform_single_channel,
31
- SDO_CHANNELS,
32
- AIA_CHANNELS,
33
- HMI_CHANNELS
34
- )
35
-
36
- # --- Global Cache for Model and Data ---
37
- # We use a simple dictionary to act as a cache to avoid reloading.
38
  APP_CACHE = {
39
  "model": None,
40
  "config": None,
41
  "scalers": None,
42
- "dataset": None,
43
- "dataloader": None,
44
  "device": "cuda" if torch.cuda.is_available() else "cpu",
45
  }
46
 
47
- # --- 1. Setup and Data Download ---
48
- @gr.cache
49
- def setup_environment_and_download_data():
 
 
 
 
 
 
50
  """
51
- Downloads all necessary files from Hugging Face: model, config, scalers, and validation data.
52
- Also creates the necessary index file for the dataset loader.
53
- This function is cached by Gradio to run only once.
54
  """
55
- logger.info("Setting up environment. This will run only once.")
56
- local_dir = "data/Surya-1.0"
57
- # Download model, config, and scalers
 
 
 
58
  snapshot_download(
59
  repo_id="nasa-ibm-ai4science/Surya-1.0",
60
- local_dir=local_dir,
61
  allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
62
  )
63
-
64
- # Download validation data
65
- data_dir = "data/Surya-1.0_validation_data"
66
  snapshot_download(
67
  repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
68
  repo_type="dataset",
69
- local_dir=data_dir,
70
  allow_patterns="20140107_1[5-9]??.nc",
71
  )
72
 
73
- # The test script requires an index file. We'll create it dynamically.
74
- index_dir = "data/test_indices"
75
- os.makedirs(index_dir, exist_ok=True)
76
- index_file_path = os.path.join(index_dir, "test_surya_index.csv")
 
77
 
78
- with open(index_file_path, "w") as f:
79
- f.write("path\n")
80
- # Find the downloaded NetCDF files and write their paths to the index
81
- search_path = os.path.join(data_dir, "**", "*.nc")
82
- for nc_file in sorted(glob.glob(search_path, recursive=True)):
83
- f.write(f"{nc_file}\n")
84
- logger.info(f"Created index file at {index_file_path}")
85
- return index_file_path, local_dir
86
-
87
- # --- 2. Model and Data Loading ---
88
- def load_essentials(model_dir):
89
- """Loads config, scalers, and the model into the APP_CACHE."""
90
- if APP_CACHE["model"] is None:
91
- logger.info("Loading config, scalers, and model for the first time...")
92
- # Load config
93
- with open(os.path.join(model_dir, "config.yaml")) as fp:
94
- config = yaml.safe_load(fp)
95
- APP_CACHE["config"] = config
96
-
97
- # Build scalers for data normalization
98
- scalers_info = yaml.safe_load(open(os.path.join(model_dir, "scalers.yaml"), "r"))
99
- APP_CACHE["scalers"] = build_scalers(info=scalers_info)
100
-
101
- # Initialize model from config
102
- model = HelioSpectFormer(
103
- img_size=config["model"]["img_size"],
104
- patch_size=config["model"]["patch_size"],
105
- in_chans=len(config["data"]["sdo_channels"]),
106
- embed_dim=config["model"]["embed_dim"],
107
- time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
108
- depth=config["model"]["depth"],
109
- n_spectral_blocks=config["model"]["n_spectral_blocks"],
110
- num_heads=config["model"]["num_heads"],
111
- mlp_ratio=config["model"]["mlp_ratio"],
112
- drop_rate=config["model"]["drop_rate"],
113
- dtype=torch.bfloat16,
114
- window_size=config["model"]["window_size"],
115
- dp_rank=config["model"]["dp_rank"],
116
- learned_flow=config["model"]["learned_flow"],
117
- use_latitude_in_learned_flow=config["model"]["learned_flow"],
118
- init_weights=False,
119
- checkpoint_layers=[i for i in range(config["model"]["depth"])],
120
- rpe=config["model"]["rpe"],
121
- ensemble=config["model"]["ensemble"],
122
- finetune=config["model"]["finetune"],
123
- )
124
-
125
- # Load pre-trained weights
126
- path_weights = os.path.join(model_dir, "surya.366m.v1.pt")
127
- weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
128
- model.load_state_dict(weights, strict=True)
129
- model.to(APP_CACHE["device"])
130
- model.eval()
131
-
132
- n_params = sum(p.numel() for p in model.parameters()) / 1e6
133
- logger.info(f"Surya FM: {n_params:.2f}M parameters loaded to {APP_CACHE['device']}.")
134
- APP_CACHE["model"] = model
135
-
136
- def get_dataloader(index_path):
137
- """Initializes and returns a DataLoader for the validation data."""
138
- if APP_CACHE["dataloader"] is None:
139
- logger.info("Initializing dataset and dataloader...")
140
- config = APP_CACHE["config"]
141
- dataset = HelioNetCDFDataset(
142
- index_path=index_path,
143
- time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
144
- time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
145
- n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
146
- rollout_steps=1,
147
- channels=config["data"]["sdo_channels"],
148
- scalers=APP_CACHE["scalers"],
149
- phase="valid", # Important: ensure no random augmentations
150
- )
151
- dataloader = DataLoader(
152
- dataset, shuffle=False, batch_size=1, num_workers=2,
153
- pin_memory=True, drop_last=False, collate_fn=custom_collate_fn,
154
- )
155
- APP_CACHE["dataloader"] = dataloader
156
- APP_CACHE["dataset"] = dataset # Also cache dataset for transformation info
157
- return APP_CACHE["dataloader"]
158
-
159
-
160
- # --- 3. Core Inference and Visualization Logic ---
161
- def run_model_inference():
162
  """
163
- Performs a single prediction step using the loaded model and dataloader.
164
- Returns the raw input, prediction, and ground truth tensors.
165
  """
 
 
 
166
  model = APP_CACHE["model"]
167
- dataloader = APP_CACHE["dataloader"]
168
  device = APP_CACHE["device"]
169
 
170
- # Get the first (and only) batch of data from the validation set
171
- batch_data, batch_metadata = next(iter(dataloader))
 
 
 
 
 
172
 
173
- logger.info("Running inference on the validation batch...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  with torch.no_grad():
175
- # Prepare input batch for the model
176
- input_batch = {key: batch_data[key].to(device) for key in ["ts", "time_delta_input"]}
177
- # Run model prediction
178
- with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
179
- prediction_tensor = model(input_batch)
180
-
181
- # Get the input and target tensors for comparison
182
- input_tensor = input_batch["ts"].to(dtype=torch.float32).cpu()
183
- target_tensor = batch_data["forecast"].cpu()
184
- prediction_tensor = prediction_tensor.to(dtype=torch.float32).cpu()
185
-
186
- logger.info("Inference complete.")
187
- return input_tensor, prediction_tensor, target_tensor
 
 
 
 
 
 
 
 
 
188
 
189
- def create_visualizations(channel_name, input_tensor, prediction_tensor, target_tensor):
190
  """
191
- Takes raw tensors and a channel name, applies inverse transformation,
192
- and converts them to displayable PIL Images.
193
  """
194
- if input_tensor is None:
195
- return None, None, None, "Please run the forecast first."
196
 
197
- logger.info(f"Creating visualization for channel: {channel_name}")
198
  c_idx = SDO_CHANNELS.index(channel_name)
199
- dataset = APP_CACHE["dataset"]
200
- means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()
201
-
202
- # --- Denormalize data for visualization ---
203
- # Final input image given to the model (last in sequence)
204
  input_slice = inverse_transform_single_channel(
205
- input_tensor[0, c_idx, -1, :, :].numpy(),
206
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
207
  )
208
- # Model's prediction
209
  pred_slice = inverse_transform_single_channel(
210
- prediction_tensor[0, c_idx, :, :].numpy(),
211
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
212
  )
213
- # Ground truth image
214
  target_slice = inverse_transform_single_channel(
215
- target_tensor[0, c_idx, 0, :, :].numpy(),
216
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
217
  )
218
-
219
- # --- Convert to images ---
220
- # Use a shared color scale for better comparison, clipped at 99.5th percentile
221
- vmax = np.quantile(np.concatenate([input_slice, pred_slice, target_slice]), 0.995)
222
-
223
- # Determine colormap from channel name
224
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
225
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
 
 
 
 
 
226
 
227
- def to_pil(data, vmin=0, vmax=vmax, cmap=cmap):
228
- data_clipped = np.clip(data, vmin, vmax)
229
- data_norm = (data_clipped - vmin) / (vmax - vmin)
230
- return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8))
231
 
232
- return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), f"Displaying forecast for {channel_name}"
233
 
234
  # --- 4. Gradio Controller Functions ---
235
- def forecast_controller(channel_name, progress=gr.Progress()):
236
- """
237
- Main function triggered by the 'Generate' button. Orchestrates the entire pipeline.
238
- """
239
- progress(0, desc="Downloading model and data (first launch only)...")
240
- index_path, model_dir = setup_environment_and_download_data()
241
 
242
- progress(0.4, desc="Loading model and building data pipeline...")
243
- load_essentials(model_dir)
244
- get_dataloader(index_path)
 
245
 
246
- progress(0.7, desc=f"Running inference on {APP_CACHE['device']}...")
247
- input_t, pred_t, target_t = run_model_inference()
248
-
249
- progress(0.9, desc="Creating visualizations...")
250
- img_in, img_pred, img_target, status = create_visualizations(channel_name, input_t, pred_t, target_t)
251
 
252
- return img_in, img_pred, img_target, status, input_t, pred_t, target_t
 
 
 
 
 
 
 
 
 
 
 
253
 
254
 
255
  # --- 5. Gradio UI Layout ---
256
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
257
- # Hidden state variables to store the raw tensors after inference
258
- state_input = gr.State()
259
- state_prediction = gr.State()
260
- state_target = gr.State()
261
-
262
  gr.Markdown(
263
  """
264
- <div align="center">
265
  # ☀️ Surya: Live Model Demo ☀️
266
  ### An Interactive Interface for NASA's Heliophysics Foundation Model
267
  This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
268
- Click the button to generate a forecast, then use the dropdown to explore the results across different SDO instrument channels.
 
269
  </div>
270
  """
271
  )
272
-
273
- with gr.Row():
274
- channel_selector = gr.Dropdown(
275
- choices=SDO_CHANNELS,
276
- value="aia171",
277
- label="🛰️ Select SDO Instrument Channel",
278
- info="Choose which solar observation channel to visualize."
279
- )
280
- run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary", scale=2)
281
-
282
- status_box = gr.Textbox(label="Status", interactive=False, value="Ready. Press 'Generate Forecast' to start.")
283
-
284
  with gr.Row():
285
- with gr.Column():
286
- gr.Markdown("### ⬅️ Final Input Image")
287
- gr.Markdown("The last observation shown to the model (T-1).")
288
- input_display = gr.Image(label="Input", height=512, width=512, interactive=False)
289
- with gr.Column():
290
- gr.Markdown("### 🔮 Model's Forecast")
291
- gr.Markdown("Surya's prediction for the next timestep (T+0).")
292
- prediction_display = gr.Image(label="Prediction", height=512, width=512, interactive=False)
293
- with gr.Column():
294
- gr.Markdown("### ✅ Ground Truth")
295
- gr.Markdown("What the Sun *actually* looked like at T+0.")
296
- label_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
297
-
 
 
298
  # --- Event Handlers ---
299
  run_button.click(
300
  fn=forecast_controller,
301
- inputs=[channel_selector],
302
- outputs=[input_display, prediction_display, label_display, status_box, state_input, state_prediction, state_target]
303
  )
304
 
 
305
  channel_selector.change(
306
- fn=create_visualizations,
307
- inputs=[channel_selector, state_input, state_prediction, state_target],
308
- outputs=[input_display, prediction_display, label_display, status_box]
 
 
 
 
 
309
  )
310
 
311
  if __name__ == "__main__":
312
- # The 'surya_dependencies.py' file must be in the same directory as this script.
313
- # Create the placeholder file if it doesn't exist.
314
- if not os.path.exists("surya_dependencies.py"):
315
- raise FileNotFoundError("The required 'surya_dependencies.py' file is missing. Please download it from the provided source.")
316
  demo.launch(debug=True)
 
1
+ # Save this file as app.py in the root of the cloned Surya repository
2
+
3
  import gradio as gr
4
  import torch
5
  import torch.nn.functional as F
 
13
  import glob
14
  import warnings
15
  import logging
16
+ import matplotlib.pyplot as plt
17
+
18
+ # --- Use the official Surya modules now that we are in the repo ---
19
+ from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
20
+ from surya.models.helio_spectformer import HelioSpectFormer
21
+ from surya.utils.data import build_scalers, custom_collate_fn
22
 
23
+ # Suppress verbose logging and warnings for a cleaner UI
24
  warnings.filterwarnings("ignore")
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
 
28
+ # --- Global cache to store expensive-to-load objects ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  APP_CACHE = {
30
  "model": None,
31
  "config": None,
32
  "scalers": None,
33
+ "full_results": None, # Will store all prediction results
 
34
  "device": "cuda" if torch.cuda.is_available() else "cpu",
35
  }
36
 
37
+ # SDO channels from the test script for the dropdown menu
38
+ SDO_CHANNELS = [
39
+ "aia94", "aia131", "aia171", "aia193", "aia211", "aia304", "aia335",
40
+ "aia1600", "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
41
+ ]
42
+
43
+ # --- 1. Setup, Download, and Model Loading (adapting fixtures from test_surya.py) ---
44
+
45
+ def setup_and_load_model(progress=gr.Progress()):
46
  """
47
+ Handles all initial setup: downloading data, loading configs, and initializing the model.
48
+ This function will populate the APP_CACHE.
 
49
  """
50
+ if APP_CACHE["model"] is not None:
51
+ logger.info("Model and data already loaded. Skipping setup.")
52
+ return
53
+
54
+ # --- Part A: Download data (from download_data fixture) ---
55
+ progress(0.1, desc="Downloading model weights and config...")
56
  snapshot_download(
57
  repo_id="nasa-ibm-ai4science/Surya-1.0",
58
+ local_dir="data/Surya-1.0",
59
  allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
60
  )
61
+ progress(0.3, desc="Downloading validation data for 2014-01-07...")
 
 
62
  snapshot_download(
63
  repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
64
  repo_type="dataset",
65
+ local_dir="data/Surya-1.0_validation_data",
66
  allow_patterns="20140107_1[5-9]??.nc",
67
  )
68
 
69
+ # --- Part B: Load Config and Scalers (from config & scalers fixtures) ---
70
+ progress(0.5, desc="Loading configuration and data scalers...")
71
+ with open("data/Surya-1.0/config.yaml") as fp:
72
+ config = yaml.safe_load(fp)
73
+ APP_CACHE["config"] = config
74
 
75
+ scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
76
+ APP_CACHE["scalers"] = build_scalers(info=scalers_info)
77
+
78
+ # --- Part C: Initialize and load model (from model fixture and test function) ---
79
+ progress(0.7, desc="Initializing model architecture...")
80
+ model_config = config["model"]
81
+ model = HelioSpectFormer(
82
+ img_size=model_config["img_size"], patch_size=model_config["patch_size"],
83
+ in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
84
+ time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
85
+ depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
86
+ num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
87
+ drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
88
+ window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
89
+ learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
90
+ init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
91
+ rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
92
+ )
93
+
94
+ progress(0.8, desc=f"Loading model weights to {APP_CACHE['device']}...")
95
+ path_weights = "data/Surya-1.0/surya.366m.v1.pt"
96
+ weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
97
+ model.load_state_dict(weights, strict=True)
98
+ model.to(APP_CACHE["device"])
99
+ model.eval()
100
+
101
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
102
+ logger.info(f"Surya FM: {n_params:.2f}M parameters loaded.")
103
+ APP_CACHE["model"] = model
104
+
105
+ # --- 2. Inference Logic (adapting the test loop) ---
106
+
107
+ def run_full_forecast():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  """
109
+ Runs inference on the entire validation dataset and stores results.
 
110
  """
111
+ if APP_CACHE["full_results"] is not None:
112
+ return APP_CACHE["full_results"]
113
+
114
  model = APP_CACHE["model"]
115
+ config = APP_CACHE["config"]
116
  device = APP_CACHE["device"]
117
 
118
+ # Create the index file needed by the dataset loader
119
+ os.makedirs("tests", exist_ok=True)
120
+ with open("tests/test_surya_index.csv", "w") as f:
121
+ f.write("path\n")
122
+ search_path = os.path.join("data/Surya-1.0_validation_data", "**", "*.nc")
123
+ for nc_file in sorted(glob.glob(search_path, recursive=True)):
124
+ f.write(f"{nc_file}\n")
125
 
126
+ # Setup dataset and dataloader (from dataset & dataloader fixtures)
127
+ dataset = HelioNetCDFDataset(
128
+ index_path="tests/test_surya_index.csv",
129
+ time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
130
+ time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
131
+ n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
132
+ rollout_steps=1, channels=config["data"]["sdo_channels"],
133
+ scalers=APP_CACHE["scalers"], phase="valid",
134
+ )
135
+ dataloader = DataLoader(
136
+ dataset, shuffle=False, batch_size=1, num_workers=2,
137
+ pin_memory=True, collate_fn=custom_collate_fn
138
+ )
139
+
140
+ all_results = []
141
  with torch.no_grad():
142
+ for batch_data, batch_metadata in dataloader:
143
+ input_batch = {k: v.to(device) for k, v in batch_data.items() if k in ["ts", "time_delta_input"]}
144
+
145
+ with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
146
+ prediction = model(input_batch)
147
+
148
+ # Store the relevant tensors on CPU for later visualization
149
+ result = {
150
+ "input": input_batch["ts"].cpu(),
151
+ "prediction": prediction.cpu(),
152
+ "target": batch_data["forecast"].cpu(),
153
+ "input_timestamp": np.datetime_as_string(batch_metadata["timestamps_input"][0][-1], unit='s'),
154
+ "target_timestamp": np.datetime_as_string(batch_metadata["timestamps_targets"][0][0], unit='s'),
155
+ }
156
+ all_results.append(result)
157
+
158
+ APP_CACHE["full_results"] = all_results
159
+ # Cache scalers needed for visualization
160
+ APP_CACHE["scalers_vis"] = dataset.transformation_inputs()
161
+ return all_results
162
+
163
+ # --- 3. Visualization Logic ---
164
 
165
+ def generate_visualization(results, timestep_index, channel_name):
166
  """
167
+ Generates PIL images for a specific timestep and channel from the results.
 
168
  """
169
+ if not results:
170
+ return None, None, None, "No results available. Please run the forecast.", ""
171
 
172
+ timestep_data = results[timestep_index]
173
  c_idx = SDO_CHANNELS.index(channel_name)
174
+ means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers_vis"]
175
+
176
+ # Denormalize data for visualization
 
 
177
  input_slice = inverse_transform_single_channel(
178
+ timestep_data["input"][0, c_idx, -1].numpy(),
179
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
180
  )
 
181
  pred_slice = inverse_transform_single_channel(
182
+ timestep_data["prediction"][0, c_idx].numpy(),
183
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
184
  )
 
185
  target_slice = inverse_transform_single_channel(
186
+ timestep_data["target"][0, c_idx, 0].numpy(),
187
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
188
  )
189
+
190
+ # Convert to PIL Images using appropriate colormaps
191
+ vmax = np.quantile(target_slice, 0.995)
 
 
 
192
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
193
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
194
+
195
+ def to_pil(data):
196
+ data_clipped = np.clip(data, 0, vmax)
197
+ data_norm = data_clipped / vmax
198
+ return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)).transpose(Image.Transpose.TRANSPOSE)
199
 
200
+ status_text = (f"Displaying Timestep {timestep_index+1}/{len(results)}\n"
201
+ f"Input: {timestep_data['input_timestamp']} | Forecast/Target: {timestep_data['target_timestamp']}")
 
 
202
 
203
+ return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), status_text
204
 
205
  # --- 4. Gradio Controller Functions ---
 
 
 
 
 
 
206
 
207
+ def forecast_controller(progress=gr.Progress(track_tqdm=True)):
208
+ """Main function for the 'Generate Forecast' button."""
209
+ progress(0, desc="Starting setup...")
210
+ setup_and_load_model(progress)
211
 
212
+ logger.info("Running forecast on all validation timesteps...")
213
+ progress(0.9, desc="Running inference on validation data...")
214
+ results = run_full_forecast()
215
+ logger.info(f"Forecast complete. {len(results)} timesteps processed.")
 
216
 
217
+ # Generate the first visualization
218
+ img_in, img_pred, img_target, status = generate_visualization(results, 0, SDO_CHANNELS[2]) # Default to aia171
219
+
220
+ # Update the slider to be interactive and have the correct number of steps
221
+ slider_update = gr.Slider(minimum=1, maximum=len(results), step=1, value=1, interactive=True,
222
+ label="Forecast Timestep")
223
+
224
+ return results, img_in, img_pred, img_target, status, slider_update
225
+
226
+ def update_visualization_controller(results, timestep, channel):
227
+ """Called when a slider or dropdown is changed."""
228
+ return generate_visualization(results, timestep - 1, channel)
229
 
230
 
231
  # --- 5. Gradio UI Layout ---
232
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
233
+ # State object to hold the results of the full inference run
234
+ state_results = gr.State()
235
+
 
 
236
  gr.Markdown(
237
  """
238
+ <div align='center'>
239
  # ☀️ Surya: Live Model Demo ☀️
240
  ### An Interactive Interface for NASA's Heliophysics Foundation Model
241
  This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
242
+ <br>
243
+ **Instructions:** 1. Click 'Generate Forecast'. 2. Use the controls to explore the results.
244
  </div>
245
  """
246
  )
247
+
 
 
 
 
 
 
 
 
 
 
 
248
  with gr.Row():
249
+ with gr.Column(scale=1):
250
+ run_button = gr.Button("🔮 1. Generate Full Forecast", variant="primary")
251
+ status_box = gr.Textbox(label="Status", interactive=False, value="Ready.", lines=2)
252
+ channel_selector = gr.Dropdown(
253
+ choices=SDO_CHANNELS, value="aia171", label="🛰️ 2. Select SDO Channel"
254
+ )
255
+ timestep_slider = gr.Slider(
256
+ minimum=1, maximum=8, step=1, value=1, interactive=False, label="Forecast Timestep"
257
+ )
258
+ with gr.Column(scale=3):
259
+ with gr.Row():
260
+ input_display = gr.Image(label="Last Input", height=512, width=512, interactive=False)
261
+ prediction_display = gr.Image(label="Model Forecast", height=512, width=512, interactive=False)
262
+ target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
263
+
264
  # --- Event Handlers ---
265
  run_button.click(
266
  fn=forecast_controller,
267
+ outputs=[state_results, input_display, prediction_display, target_display, status_box, timestep_slider]
 
268
  )
269
 
270
+ # When the user changes the channel or timestep, call the visualization update function
271
  channel_selector.change(
272
+ fn=update_visualization_controller,
273
+ inputs=[state_results, timestep_slider, channel_selector],
274
+ outputs=[input_display, prediction_display, target_display, status_box]
275
+ )
276
+ timestep_slider.change(
277
+ fn=update_visualization_controller,
278
+ inputs=[state_results, timestep_slider, channel_selector],
279
+ outputs=[input_display, prediction_display, target_display, status_box]
280
  )
281
 
282
  if __name__ == "__main__":
 
 
 
 
283
  demo.launch(debug=True)