broadfield-dev commited on
Commit
b1a1b26
·
verified ·
1 Parent(s): 2870322

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -201
app.py CHANGED
@@ -2,81 +2,71 @@
2
 
3
  import gradio as gr
4
  import torch
5
- import torch.nn.functional as F
6
  from torch.utils.data import DataLoader
7
  from huggingface_hub import snapshot_download
8
  import yaml
9
  import numpy as np
10
  from PIL import Image
11
- import sunpy.visualization.colormaps as sunpy_cm
 
 
 
 
 
 
 
12
  import os
13
- import glob
14
  import warnings
15
  import logging
 
16
  import matplotlib.pyplot as plt
 
17
 
18
- # --- Use the official Surya modules now that we are in the repo ---
19
- from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
20
  from surya.models.helio_spectformer import HelioSpectFormer
21
- from surya.utils.data import build_scalers, custom_collate_fn
22
 
23
- # Suppress verbose logging and warnings for a cleaner UI
24
  warnings.filterwarnings("ignore")
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
 
28
- # --- Global cache to store expensive-to-load objects ---
29
- APP_CACHE = {
30
- "model": None,
31
- "config": None,
32
- "scalers": None,
33
- "full_results": None, # Will store all prediction results
34
- "device": "cuda" if torch.cuda.is_available() else "cpu",
 
 
 
 
 
 
 
 
 
35
  }
 
36
 
37
- # SDO channels from the test script for the dropdown menu
38
- SDO_CHANNELS = [
39
- "aia94", "aia131", "aia171", "aia193", "aia211", "aia304", "aia335",
40
- "aia1600", "hmi_m", "hmi_bx", "hmi_by", "hmi_bz", "hmi_v",
41
- ]
42
-
43
- # --- 1. Setup, Download, and Model Loading (adapting fixtures from test_surya.py) ---
44
-
45
  def setup_and_load_model(progress=gr.Progress()):
46
- """
47
- Handles all initial setup: downloading data, loading configs, and initializing the model.
48
- This function will populate the APP_CACHE.
49
- """
50
- if APP_CACHE["model"] is not None:
51
- logger.info("Model and data already loaded. Skipping setup.")
52
  return
53
 
54
- # --- Part A: Download data (from download_data fixture) ---
55
- progress(0.1, desc="Downloading model weights and config...")
56
- snapshot_download(
57
- repo_id="nasa-ibm-ai4science/Surya-1.0",
58
- local_dir="data/Surya-1.0",
59
- allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
60
- )
61
- progress(0.3, desc="Downloading validation data for 2014-01-07...")
62
- snapshot_download(
63
- repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
64
- repo_type="dataset",
65
- local_dir="data/Surya-1.0_validation_data",
66
- allow_patterns="20140107_1[5-9]??.nc",
67
- )
68
 
69
- # --- Part B: Load Config and Scalers (from config & scalers fixtures) ---
70
- progress(0.5, desc="Loading configuration and data scalers...")
71
  with open("data/Surya-1.0/config.yaml") as fp:
72
  config = yaml.safe_load(fp)
73
  APP_CACHE["config"] = config
74
-
75
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
76
  APP_CACHE["scalers"] = build_scalers(info=scalers_info)
77
-
78
- # --- Part C: Initialize and load model (from model fixture and test function) ---
79
- progress(0.7, desc="Initializing model architecture...")
80
  model_config = config["model"]
81
  model = HelioSpectFormer(
82
  img_size=model_config["img_size"], patch_size=model_config["patch_size"],
@@ -90,194 +80,214 @@ def setup_and_load_model(progress=gr.Progress()):
90
  init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
91
  rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
92
  )
93
-
94
- progress(0.8, desc=f"Loading model weights to {APP_CACHE['device']}...")
95
- path_weights = "data/Surya-1.0/surya.366m.v1.pt"
96
- weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
97
  model.load_state_dict(weights, strict=True)
98
- model.to(APP_CACHE["device"])
99
  model.eval()
100
-
101
- n_params = sum(p.numel() for p in model.parameters()) / 1e6
102
- logger.info(f"Surya FM: {n_params:.2f}M parameters loaded.")
103
  APP_CACHE["model"] = model
 
104
 
105
- # --- 2. Inference Logic (adapting the test loop) ---
106
-
107
- def run_full_forecast():
108
- """
109
- Runs inference on the entire validation dataset and stores results.
110
- """
111
- if APP_CACHE["full_results"] is not None:
112
- return APP_CACHE["full_results"]
113
-
114
- model = APP_CACHE["model"]
115
  config = APP_CACHE["config"]
116
- device = APP_CACHE["device"]
117
-
118
- # Create the index file needed by the dataset loader
119
- os.makedirs("tests", exist_ok=True)
120
- with open("tests/test_surya_index.csv", "w") as f:
121
- f.write("path\n")
122
- search_path = os.path.join("data/Surya-1.0_validation_data", "**", "*.nc")
123
- for nc_file in sorted(glob.glob(search_path, recursive=True)):
124
- f.write(f"{nc_file}\n")
125
 
126
- # Setup dataset and dataloader (from dataset & dataloader fixtures)
127
- dataset = HelioNetCDFDataset(
128
- index_path="tests/test_surya_index.csv",
129
- time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
130
- time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
131
- n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
132
- rollout_steps=1, channels=config["data"]["sdo_channels"],
133
- scalers=APP_CACHE["scalers"], phase="valid",
134
- )
135
- dataloader = DataLoader(
136
- dataset, shuffle=False, batch_size=1, num_workers=2,
137
- pin_memory=True, collate_fn=custom_collate_fn
138
- )
139
-
140
- all_results = []
141
- with torch.no_grad():
142
- for batch_data, batch_metadata in dataloader:
143
- input_batch = {k: v.to(device) for k, v in batch_data.items() if k in ["ts", "time_delta_input"]}
144
 
145
- with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
146
- prediction = model(input_batch)
 
 
 
 
 
 
147
 
148
- # Store the relevant tensors on CPU for later visualization
149
- result = {
150
- "input": input_batch["ts"].cpu(),
151
- "prediction": prediction.cpu(),
152
- "target": batch_data["forecast"].cpu(),
153
- "input_timestamp": np.datetime_as_string(batch_metadata["timestamps_input"][0][-1], unit='s'),
154
- "target_timestamp": np.datetime_as_string(batch_metadata["timestamps_targets"][0][0], unit='s'),
155
- }
156
- all_results.append(result)
157
-
158
- APP_CACHE["full_results"] = all_results
159
- # Cache scalers needed for visualization
160
- APP_CACHE["scalers_vis"] = dataset.transformation_inputs()
161
- return all_results
162
-
163
- # --- 3. Visualization Logic ---
164
-
165
- def generate_visualization(results, timestep_index, channel_name):
166
- """
167
- Generates PIL images for a specific timestep and channel from the results.
168
- """
169
- if not results:
170
- return None, None, None, "No results available. Please run the forecast.", ""
171
-
172
- timestep_data = results[timestep_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  c_idx = SDO_CHANNELS.index(channel_name)
174
- means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers_vis"]
175
-
176
- # Denormalize data for visualization
177
- input_slice = inverse_transform_single_channel(
178
- timestep_data["input"][0, c_idx, -1].numpy(),
179
- mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
180
- )
181
  pred_slice = inverse_transform_single_channel(
182
- timestep_data["prediction"][0, c_idx].numpy(),
183
- mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
184
- )
185
- target_slice = inverse_transform_single_channel(
186
- timestep_data["target"][0, c_idx, 0].numpy(),
187
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
188
  )
189
-
190
- # Convert to PIL Images using appropriate colormaps
191
- vmax = np.quantile(target_slice, 0.995)
192
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
193
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
194
 
195
- def to_pil(data):
196
- data_clipped = np.clip(data, 0, vmax)
197
- data_norm = data_clipped / vmax
198
- return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)).transpose(Image.Transpose.TRANSPOSE)
199
-
200
- status_text = (f"Displaying Timestep {timestep_index+1}/{len(results)}\n"
201
- f"Input: {timestep_data['input_timestamp']} | Forecast/Target: {timestep_data['target_timestamp']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), status_text
204
 
205
- # --- 4. Gradio Controller Functions ---
206
-
207
- def forecast_controller(progress=gr.Progress(track_tqdm=True)):
208
- """Main function for the 'Generate Forecast' button."""
209
- progress(0, desc="Starting setup...")
210
- setup_and_load_model(progress)
211
-
212
- logger.info("Running forecast on all validation timesteps...")
213
- progress(0.9, desc="Running inference on validation data...")
214
- results = run_full_forecast()
215
- logger.info(f"Forecast complete. {len(results)} timesteps processed.")
216
-
217
- # Generate the first visualization
218
- img_in, img_pred, img_target, status = generate_visualization(results, 0, SDO_CHANNELS[2]) # Default to aia171
219
-
220
- # Update the slider to be interactive and have the correct number of steps
221
- slider_update = gr.Slider(minimum=1, maximum=len(results), step=1, value=1, interactive=True,
222
- label="Forecast Timestep")
223
-
224
- return results, img_in, img_pred, img_target, status, slider_update
225
-
226
- def update_visualization_controller(results, timestep, channel):
227
- """Called when a slider or dropdown is changed."""
228
- return generate_visualization(results, timestep - 1, channel)
229
-
230
-
231
- # --- 5. Gradio UI Layout ---
232
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
233
- # State object to hold the results of the full inference run
234
- state_results = gr.State()
 
 
235
 
236
  gr.Markdown(
237
  """
238
  <div align='center'>
239
- # ☀️ Surya: Live Model Demo ☀️
240
- ### An Interactive Interface for NASA's Heliophysics Foundation Model
241
- This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
242
- <br>
243
- **Instructions:** 1. Click 'Generate Forecast'. 2. Use the controls to explore the results.
 
244
  </div>
245
  """
246
  )
247
 
248
  with gr.Row():
249
- with gr.Column(scale=1):
250
- run_button = gr.Button("🔮 1. Generate Full Forecast", variant="primary")
251
- status_box = gr.Textbox(label="Status", interactive=False, value="Ready.", lines=2)
252
- channel_selector = gr.Dropdown(
253
- choices=SDO_CHANNELS, value="aia171", label="🛰️ 2. Select SDO Channel"
254
- )
255
- timestep_slider = gr.Slider(
256
- minimum=1, maximum=8, step=1, value=1, interactive=False, label="Forecast Timestep"
257
- )
258
- with gr.Column(scale=3):
259
- with gr.Row():
260
- input_display = gr.Image(label="Last Input", height=512, width=512, interactive=False)
261
- prediction_display = gr.Image(label="Model Forecast", height=512, width=512, interactive=False)
262
- target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
263
 
264
- # --- Event Handlers ---
265
  run_button.click(
266
  fn=forecast_controller,
267
- outputs=[state_results, input_display, prediction_display, target_display, status_box, timestep_slider]
 
 
268
  )
269
 
270
- # When the user changes the channel or timestep, call the visualization update function
271
  channel_selector.change(
272
  fn=update_visualization_controller,
273
- inputs=[state_results, timestep_slider, channel_selector],
274
- outputs=[input_display, prediction_display, target_display, status_box]
275
- )
276
- timestep_slider.change(
277
- fn=update_visualization_controller,
278
- inputs=[state_results, timestep_slider, channel_selector],
279
- outputs=[input_display, prediction_display, target_display, status_box]
280
  )
281
 
282
  if __name__ == "__main__":
 
 
283
  demo.launch(debug=True)
 
2
 
3
  import gradio as gr
4
  import torch
 
5
  from torch.utils.data import DataLoader
6
  from huggingface_hub import snapshot_download
7
  import yaml
8
  import numpy as np
9
  from PIL import Image
10
+ import sunpy.map
11
+ import sunpy.net.attrs as a
12
+ from sunpy.net import Fido
13
+ from sunpy.coordinates import Helioprojective
14
+ from astropy.coordinates import SkyCoord
15
+ from astropy.wcs import WCS
16
+ import astropy.units as u
17
+ from reproject import reproject_interp
18
  import os
 
19
  import warnings
20
  import logging
21
+ import datetime
22
  import matplotlib.pyplot as plt
23
+ import sunpy.visualization.colormaps as sunpy_cm
24
 
25
+ # --- Use the official Surya modules ---
 
26
  from surya.models.helio_spectformer import HelioSpectFormer
27
+ from surya.utils.data import build_scalers, inverse_transform_single_channel
28
 
29
+ # --- Configuration ---
30
  warnings.filterwarnings("ignore")
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
34
+ # Global cache for model, config, etc.
35
+ APP_CACHE = {}
36
+ SDO_CHANNELS_MAP = {
37
+ "aia94": (a.Wavelength(94, 94, "angstrom"), a.Sample(12 * u.s)),
38
+ "aia131": (a.Wavelength(131, 131, "angstrom"), a.Sample(12 * u.s)),
39
+ "aia171": (a.Wavelength(171, 171, "angstrom"), a.Sample(12 * u.s)),
40
+ "aia193": (a.Wavelength(193, 193, "angstrom"), a.Sample(12 * u.s)),
41
+ "aia211": (a.Wavelength(211, 211, "angstrom"), a.Sample(12 * u.s)),
42
+ "aia304": (a.Wavelength(304, 304, "angstrom"), a.Sample(12 * u.s)),
43
+ "aia335": (a.Wavelength(335, 335, "angstrom"), a.Sample(12 * u.s)),
44
+ "aia1600": (a.Wavelength(1600, 1600, "angstrom"), a.Sample(24 * u.s)),
45
+ "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
46
+ "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
47
+ "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
48
+ "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
49
+ "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
50
  }
51
+ SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
52
 
53
+ # --- 1. Model Loading and Setup ---
 
 
 
 
 
 
 
54
  def setup_and_load_model(progress=gr.Progress()):
55
+ if "model" in APP_CACHE:
 
 
 
 
 
56
  return
57
 
58
+ progress(0.1, desc="Downloading model files (first run only)...")
59
+ snapshot_download(repo_id="nasa-ibm-ai4science/Surya-1.0", local_dir="data/Surya-1.0",
60
+ allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"])
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ progress(0.5, desc="Loading configuration and scalers...")
 
63
  with open("data/Surya-1.0/config.yaml") as fp:
64
  config = yaml.safe_load(fp)
65
  APP_CACHE["config"] = config
 
66
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
67
  APP_CACHE["scalers"] = build_scalers(info=scalers_info)
68
+
69
+ progress(0.7, desc="Initializing and loading model...")
 
70
  model_config = config["model"]
71
  model = HelioSpectFormer(
72
  img_size=model_config["img_size"], patch_size=model_config["patch_size"],
 
80
  init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
81
  rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
82
  )
83
+ device = "cuda" if torch.cuda.is_available() else "cpu"
84
+ APP_CACHE["device"] = device
85
+ weights = torch.load(f"data/Surya-1.0/surya.366m.v1.pt", map_location=torch.device(device))
 
86
  model.load_state_dict(weights, strict=True)
87
+ model.to(device)
88
  model.eval()
 
 
 
89
  APP_CACHE["model"] = model
90
+ logger.info("Model setup complete.")
91
 
92
+ # --- 2. Live Data Fetching and Preprocessing ---
93
+ def fetch_and_process_sdo_data(target_dt, progress):
 
 
 
 
 
 
 
 
94
  config = APP_CACHE["config"]
95
+ img_size = config["model"]["img_size"][0]
 
 
 
 
 
 
 
 
96
 
97
+ # Define time windows for input and target (ground truth)
98
+ input_deltas = config["data"]["time_delta_input_minutes"]
99
+ target_delta = config["data"]["time_delta_target_minutes"][0]
100
+ input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
101
+ target_time = target_dt + datetime.timedelta(minutes=target_delta)
102
+ all_times = sorted(list(set(input_times + [target_time])))
103
+
104
+ # Download data for all required timestamps
105
+ data_maps = {}
106
+ total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
107
+ downloads_done = 0
108
+ for t in all_times:
109
+ data_maps[t] = {}
110
+ for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
111
+ progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
 
 
 
112
 
113
+ # HMI vector fields are not standard products, use LoS as a placeholder for demo
114
+ instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
115
+ if channel in ["hmi_by", "hmi_bz"]:
116
+ if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
117
+ continue
118
+
119
+ time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
120
+ query = Fido.search(time_attr, a.Instrument.aia, physobs, sample) if "aia" in channel else Fido.search(time_attr, a.Instrument.hmi, physobs, sample)
121
 
122
+ if not query: raise ValueError(f"No data found for {channel} at {t}")
123
+ files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
124
+ data_maps[t][channel] = sunpy.map.Map(files[0])
125
+ downloads_done += 1
126
+
127
+ # Create target WCS for reprojection
128
+ output_wcs = WCS(naxis=2)
129
+ output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
130
+ output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
131
+ output_wcs.wcs.crval = [0, 0] * u.arcsec
132
+ output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
133
+
134
+ # Process data
135
+ processed_tensors = {}
136
+ for t, channel_maps in data_maps.items():
137
+ channel_tensors = []
138
+ for i, channel in enumerate(SDO_CHANNELS):
139
+ progress(i / len(SDO_CHANNELS), desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
140
+ smap = channel_maps[channel]
141
+
142
+ # Reproject to common grid
143
+ reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
144
+
145
+ # Normalize by exposure time and apply signed-log transform
146
+ exp_time = smap.meta.get('exptime', 1.0)
147
+ if exp_time <= 0: exp_time = 1.0
148
+ norm_data = reprojected_data / exp_time
149
+
150
+ # Apply the same scaling as the training pipeline
151
+ scaler = APP_CACHE["scalers"][channel]
152
+ scaled_data = scaler.transform(norm_data)
153
+ channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
154
+ processed_tensors[t] = torch.stack(channel_tensors)
155
+
156
+ # Assemble final input and target tensors
157
+ input_tensor_list = [processed_tensors[t] for t in input_times]
158
+ input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0) # Add batch dim
159
+ target_map = data_maps[target_time] # Return raw map for ground truth vis
160
+ last_input_map = data_maps[input_times[-1]]
161
+
162
+ return input_tensor, last_input_map, target_map
163
+
164
+ # --- 3. Inference and Visualization ---
165
+ def run_inference(input_tensor):
166
+ logger.info("Running model inference...")
167
+ model = APP_CACHE["model"]
168
+ device = APP_CACHE["device"]
169
+
170
+ time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
171
+ time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
172
+
173
+ input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
174
+
175
+ with torch.no_grad():
176
+ with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
177
+ prediction = model(input_batch)
178
+ logger.info("Inference complete.")
179
+ return prediction.cpu()
180
+
181
+ def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
182
+ if last_input_map is None:
183
+ return None, None, None
184
+
185
  c_idx = SDO_CHANNELS.index(channel_name)
186
+
187
+ # Process Prediction
188
+ means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][SDO_CHANNELS[0]].get_params()
 
 
 
 
189
  pred_slice = inverse_transform_single_channel(
190
+ prediction_tensor[0, c_idx].numpy(),
 
 
 
 
191
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
192
  )
193
+
194
+ # Get colormap and normalization
195
+ vmax = np.quantile(target_map[channel_name].data, 0.995)
196
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
197
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
198
 
199
+ def to_pil(data, flip=False):
200
+ data_clipped = np.nan_to_num(data)
201
+ data_clipped = np.clip(data_clipped, 0, vmax)
202
+ data_norm = data_clipped / vmax if vmax > 0 else data_clipped
203
+ colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
204
+ img = Image.fromarray(colored)
205
+ return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
206
+
207
+ return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
208
+
209
+
210
+ # --- 4. Gradio UI and Controllers ---
211
+ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
212
+ try:
213
+ if not dt_str:
214
+ raise gr.Error("Please select a date and time.")
215
+
216
+ progress(0, desc="Initializing...")
217
+ setup_and_load_model(progress)
218
+
219
+ target_dt = datetime.datetime.fromisoformat(dt_str)
220
+ logger.info(f"Starting forecast for target time: {target_dt}")
221
+
222
+ input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
223
+
224
+ prediction_tensor = run_inference(input_tensor)
225
+
226
+ # Default visualization for aia171
227
+ img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
228
+
229
+ status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
230
+ logger.info(status)
231
+
232
+ return (last_input_map, prediction_tensor, target_map, # state
233
+ img_in, img_pred, img_target, status, gr.update(visible=True))
234
+
235
+ except Exception as e:
236
+ logger.error(f"An error occurred: {e}", exc_info=True)
237
+ raise gr.Error(f"Failed to generate forecast. Error: {e}")
238
+
239
+ def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
240
+ if last_input_map is None:
241
+ return None, None, None
242
+ return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
243
 
 
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
246
+ # State objects to hold the data after a forecast is run
247
+ state_last_input = gr.State()
248
+ state_prediction = gr.State()
249
+ state_target = gr.State()
250
 
251
  gr.Markdown(
252
  """
253
  <div align='center'>
254
+ # ☀️ Surya: Live Forecast Demo ☀️
255
+ ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
256
+ **Instructions:**
257
+ 1. Pick a date and time (at least 1 hour in the past).
258
+ 2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
259
+ 3. Once complete, select different channels to explore the multi-spectrum forecast.
260
  </div>
261
  """
262
  )
263
 
264
  with gr.Row():
265
+ datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
266
+ value=(datetime.datetime.now() - datetime.timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S"))
267
+ run_button = gr.Button("🔮 Generate Forecast", variant="primary")
268
+
269
+ with gr.Group(visible=False) as results_group:
270
+ status_box = gr.Textbox(label="Status", interactive=False)
271
+ channel_selector = gr.Dropdown(choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel")
272
+ with gr.Row():
273
+ input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
274
+ prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
275
+ target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
 
 
 
276
 
 
277
  run_button.click(
278
  fn=forecast_controller,
279
+ inputs=[datetime_input],
280
+ outputs=[state_last_input, state_prediction, state_target,
281
+ input_display, prediction_display, target_display, status_box, results_group]
282
  )
283
 
 
284
  channel_selector.change(
285
  fn=update_visualization_controller,
286
+ inputs=[state_last_input, state_prediction, state_target, channel_selector],
287
+ outputs=[input_display, prediction_display, target_display]
 
 
 
 
 
288
  )
289
 
290
  if __name__ == "__main__":
291
+ # Create cache directory if it doesn't exist
292
+ os.makedirs("./data/sdo_cache", exist_ok=True)
293
  demo.launch(debug=True)