broadfield-dev committed on
Commit
d4d895f
·
verified ·
1 Parent(s): bf136f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -63
app.py CHANGED
@@ -18,6 +18,7 @@ import logging
18
  import datetime
19
  import matplotlib.pyplot as plt
20
  import sunpy.visualization.colormaps as sunpy_cm
 
21
 
22
  # --- Use the official Surya modules ---
23
  from surya.models.helio_spectformer import HelioSpectFormer
@@ -33,7 +34,6 @@ logger = logging.getLogger(__name__)
33
  # Global cache for model, config, etc.
34
  APP_CACHE = {}
35
 
36
- # *** FIX: Corrected the a.Wavelength calls to use astropy units ***
37
  SDO_CHANNELS_MAP = {
38
  "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
39
  "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
@@ -52,22 +52,23 @@ SDO_CHANNELS_MAP = {
52
  SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
53
 
54
  # --- 1. Model Loading and Setup ---
55
- def setup_and_load_model(progress=gr.Progress()):
56
  if "model" in APP_CACHE:
 
57
  return
58
 
59
- progress(0.1, desc="Downloading model files (first run only)...")
60
  snapshot_download(repo_id="nasa-ibm-ai4science/Surya-1.0", local_dir="data/Surya-1.0",
61
  allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"])
62
 
63
- progress(0.5, desc="Loading configuration and scalers...")
64
  with open("data/Surya-1.0/config.yaml") as fp:
65
  config = yaml.safe_load(fp)
66
  APP_CACHE["config"] = config
67
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
68
  APP_CACHE["scalers"] = build_scalers(info=scalers_info)
69
 
70
- progress(0.7, desc="Initializing and loading model...")
71
  model_config = config["model"]
72
  model = HelioSpectFormer(
73
  img_size=model_config["img_size"], patch_size=model_config["patch_size"],
@@ -83,15 +84,17 @@ def setup_and_load_model(progress=gr.Progress()):
83
  )
84
  device = "cuda" if torch.cuda.is_available() else "cpu"
85
  APP_CACHE["device"] = device
 
 
86
  weights = torch.load(f"data/Surya-1.0/surya.366m.v1.pt", map_location=torch.device(device))
87
  model.load_state_dict(weights, strict=True)
88
  model.to(device)
89
  model.eval()
90
  APP_CACHE["model"] = model
91
- logger.info("Model setup complete.")
92
 
93
- # --- 2. Live Data Fetching and Preprocessing ---
94
- def fetch_and_process_sdo_data(target_dt, progress):
95
  config = APP_CACHE["config"]
96
  img_size = config["model"]["img_size"][0]
97
 
@@ -104,11 +107,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
104
  data_maps = {}
105
  total_downloads = len(all_times) * len(SDO_CHANNELS)
106
  downloads_done = 0
 
107
  for t in all_times:
108
  data_maps[t] = {}
109
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
110
  downloads_done += 1
111
- progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
112
 
113
  if channel in ["hmi_by", "hmi_bz"]:
114
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
@@ -122,12 +126,14 @@ def fetch_and_process_sdo_data(target_dt, progress):
122
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
123
  data_maps[t][channel] = sunpy.map.Map(files[0])
124
 
 
125
  output_wcs = WCS(naxis=2)
126
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
127
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
128
  output_wcs.wcs.crval = [0, 0] * u.arcsec
129
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
130
 
 
131
  processed_tensors = {}
132
  for t, channel_maps in data_maps.items():
133
  channel_tensors = []
@@ -139,48 +145,45 @@ def fetch_and_process_sdo_data(target_dt, progress):
139
  if exp_time is None or exp_time <= 0: exp_time = 1.0
140
  norm_data = reprojected_data / exp_time
141
 
142
- scaler = APP_CACHE["scalers"][channel]
143
- scaled_data = scaler.transform(norm_data)
144
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
145
  processed_tensors[t] = torch.stack(channel_tensors)
146
 
 
147
  input_tensor_list = [processed_tensors[t] for t in input_times]
148
  input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
149
  target_map = data_maps[target_time]
150
  last_input_map = data_maps[input_times[-1]]
151
 
152
- return input_tensor, last_input_map, target_map
 
 
153
 
154
  # --- 3. Inference and Visualization ---
 
155
  def run_inference(input_tensor):
156
- logger.info("Running model inference...")
157
  model = APP_CACHE["model"]
158
  device = APP_CACHE["device"]
159
-
160
  time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
161
  time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
162
-
163
  input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
164
-
165
  with torch.no_grad():
166
  with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
167
  prediction = model(input_batch)
168
- logger.info("Inference complete.")
169
  return prediction.cpu()
170
 
171
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
172
  if last_input_map is None: return None, None, None
173
-
174
  c_idx = SDO_CHANNELS.index(channel_name)
175
- means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][channel_name].get_params()
 
 
176
  pred_slice = inverse_transform_single_channel(
177
- prediction_tensor[0, c_idx].numpy(), mean=means, std=stds, epsilon=epsilons, sl_scale_factor=sl_scale_factors
178
  )
179
-
180
  vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
181
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
182
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
183
-
184
  def to_pil(data, flip=False):
185
  data_clipped = np.nan_to_num(data)
186
  data_clipped = np.clip(data_clipped, 0, vmax)
@@ -188,79 +191,115 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
188
  colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
189
  img = Image.fromarray(colored)
190
  return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
191
-
192
- return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
193
 
194
  # --- 4. Gradio UI and Controllers ---
195
- def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
196
  try:
197
  if not dt_str: raise gr.Error("Please select a date and time.")
198
 
199
- progress(0, desc="Initializing...")
200
- setup_and_load_model(progress)
201
-
 
 
202
  target_dt = datetime.datetime.fromisoformat(dt_str)
203
- logger.info(f"Starting forecast for target time: {target_dt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
 
206
  prediction_tensor = run_inference(input_tensor)
207
 
 
 
208
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
209
 
210
- status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
211
- logger.info(status)
212
-
213
- return (last_input_map, prediction_tensor, target_map,
214
- img_in, img_pred, img_target, status, gr.update(visible=True))
215
-
 
 
 
 
 
 
 
216
  except Exception as e:
217
- logger.error(f"An error occurred: {e}", exc_info=True)
218
- raise gr.Error(f"Failed to generate forecast. Error: {e}")
 
 
 
 
 
 
 
 
219
 
220
- def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
221
- if last_input_map is None: return None, None, None
222
- return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
223
 
224
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
225
  state_last_input = gr.State()
226
  state_prediction = gr.State()
227
  state_target = gr.State()
228
 
229
- gr.Markdown(
230
- """
231
- <div align='center'>
232
- # ☀️ Surya: Live Forecast Demo ☀️
233
- ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
234
- **Instructions:**
235
- 1. Pick a date and time (at least 3 hours in the past).
236
- 2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
237
- 3. Once complete, select different channels to explore the multi-spectrum forecast.
238
- </div>
239
- """
240
- )
241
 
242
  with gr.Row():
243
- datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
244
- value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S"))
245
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
246
 
 
 
 
247
  with gr.Group(visible=False) as results_group:
248
- status_box = gr.Textbox(label="Status", interactive=False)
249
- channel_selector = gr.Dropdown(choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel")
250
  with gr.Row():
251
- input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
252
- prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
253
- target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
254
 
 
 
255
  run_button.click(
256
  fn=forecast_controller,
257
  inputs=[datetime_input],
258
- outputs=[state_last_input, state_prediction, state_target,
259
- input_display, prediction_display, target_display, status_box, results_group]
 
 
 
260
  )
261
 
262
  channel_selector.change(
263
- fn=update_visualization_controller,
264
  inputs=[state_last_input, state_prediction, state_target, channel_selector],
265
  outputs=[input_display, prediction_display, target_display]
266
  )
 
18
  import datetime
19
  import matplotlib.pyplot as plt
20
  import sunpy.visualization.colormaps as sunpy_cm
21
+ import traceback
22
 
23
  # --- Use the official Surya modules ---
24
  from surya.models.helio_spectformer import HelioSpectFormer
 
34
  # Global cache for model, config, etc.
35
  APP_CACHE = {}
36
 
 
37
  SDO_CHANNELS_MAP = {
38
  "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
39
  "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
 
52
  SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
53
 
54
  # --- 1. Model Loading and Setup ---
55
+ def setup_and_load_model():
56
  if "model" in APP_CACHE:
57
+ yield "Model already loaded. Skipping setup."
58
  return
59
 
60
+ yield "Downloading model files (first run only)..."
61
  snapshot_download(repo_id="nasa-ibm-ai4science/Surya-1.0", local_dir="data/Surya-1.0",
62
  allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"])
63
 
64
+ yield "Loading configuration and data scalers..."
65
  with open("data/Surya-1.0/config.yaml") as fp:
66
  config = yaml.safe_load(fp)
67
  APP_CACHE["config"] = config
68
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
69
  APP_CACHE["scalers"] = build_scalers(info=scalers_info)
70
 
71
+ yield "Initializing model architecture..."
72
  model_config = config["model"]
73
  model = HelioSpectFormer(
74
  img_size=model_config["img_size"], patch_size=model_config["patch_size"],
 
84
  )
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
86
  APP_CACHE["device"] = device
87
+
88
+ yield f"Loading model weights to {device}..."
89
  weights = torch.load(f"data/Surya-1.0/surya.366m.v1.pt", map_location=torch.device(device))
90
  model.load_state_dict(weights, strict=True)
91
  model.to(device)
92
  model.eval()
93
  APP_CACHE["model"] = model
94
+ yield "Model setup complete."
95
 
96
+ # --- 2. Live Data Fetching and Preprocessing (as a generator) ---
97
+ def fetch_and_process_sdo_data(target_dt):
98
  config = APP_CACHE["config"]
99
  img_size = config["model"]["img_size"][0]
100
 
 
107
  data_maps = {}
108
  total_downloads = len(all_times) * len(SDO_CHANNELS)
109
  downloads_done = 0
110
+ yield f"Starting download of {total_downloads} data files..."
111
  for t in all_times:
112
  data_maps[t] = {}
113
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
114
  downloads_done += 1
115
+ yield f"Downloading [{downloads_done}/{total_downloads}]: {channel} for {t.strftime('%Y-%m-%d %H:%M')}..."
116
 
117
  if channel in ["hmi_by", "hmi_bz"]:
118
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
 
126
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
127
  data_maps[t][channel] = sunpy.map.Map(files[0])
128
 
129
+ yield "✅ All files downloaded. Starting preprocessing..."
130
  output_wcs = WCS(naxis=2)
131
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
132
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
133
  output_wcs.wcs.crval = [0, 0] * u.arcsec
134
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
135
 
136
+ scaler = APP_CACHE["scalers"]
137
  processed_tensors = {}
138
  for t, channel_maps in data_maps.items():
139
  channel_tensors = []
 
145
  if exp_time is None or exp_time <= 0: exp_time = 1.0
146
  norm_data = reprojected_data / exp_time
147
 
148
+ scaled_data = scaler.transform(norm_data, c_idx=i)
 
149
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
150
  processed_tensors[t] = torch.stack(channel_tensors)
151
 
152
+ yield "✅ Preprocessing complete."
153
  input_tensor_list = [processed_tensors[t] for t in input_times]
154
  input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
155
  target_map = data_maps[target_time]
156
  last_input_map = data_maps[input_times[-1]]
157
 
158
+ # Yield the results as a tuple last; the caller (forecast_controller) tells it apart from the string status updates by its type
159
+ yield (input_tensor, last_input_map, target_map)
160
+
161
 
162
  # --- 3. Inference and Visualization ---
163
+ # (These are fast and don't need to be generators)
164
def run_inference(input_tensor):
    """Run the cached model on a prepared input tensor.

    Reads ``model``, ``device`` and ``config`` from ``APP_CACHE``, so the
    cache must already have been populated by ``setup_and_load_model``.

    Args:
        input_tensor: Stacked, scaled input built by
            ``fetch_and_process_sdo_data`` (channels stacked per time step,
            time steps stacked on dim 1, plus a leading batch dim of 1).

    Returns:
        The model output as a float32 tensor on the CPU.
    """
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]

    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}

    # autocast expects the bare device type ("cuda"), not an indexed "cuda:0".
    device_type = device.split(':')[0]
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        prediction = model(input_batch)

    # Under bfloat16 autocast the output may itself be bfloat16, which
    # Tensor.numpy() (called by generate_visualization) cannot convert.
    # Cast to float32 so downstream NumPy code always works.
    return prediction.float().cpu()
174
 
175
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
176
  if last_input_map is None: return None, None, None
 
177
  c_idx = SDO_CHANNELS.index(channel_name)
178
+ scaler = APP_CACHE["scalers"]
179
+ all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
180
+ mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
181
  pred_slice = inverse_transform_single_channel(
182
+ prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
183
  )
 
184
  vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
185
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
186
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
 
187
  def to_pil(data, flip=False):
188
  data_clipped = np.nan_to_num(data)
189
  data_clipped = np.clip(data_clipped, 0, vmax)
 
191
  colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
192
  img = Image.fromarray(colored)
193
  return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
194
+ return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)
 
195
 
196
  # --- 4. Gradio UI and Controllers ---
197
def forecast_controller(dt_str):
    """Run a full forecast and stream progress updates to the Gradio UI.

    This is a generator used as a Gradio event handler. Every yielded dict
    maps UI components to a new value or ``gr.update(...)``; components not
    present in a dict are left untouched.

    Stages: model setup, data download and preprocessing, inference, and
    visualization of the ``aia171`` channel. Any exception raised along the
    way is logged and shown (with its traceback) in ``log_box`` instead of
    propagating to Gradio.

    Args:
        dt_str: Target time as an ISO-format string accepted by
            ``datetime.datetime.fromisoformat``.

    Yields:
        Dicts of component updates, ending with one that re-enables the
        run button and the date/time input.
    """
    # Initial UI state: show the log, lock the inputs, hide stale results.
    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        run_button: gr.update(interactive=False),
        datetime_input: gr.update(interactive=False),
        results_group: gr.update(visible=False),
    }

    try:
        if not dt_str:
            raise gr.Error("Please select a date and time.")

        # --- Stage 1: Setup Model ---
        for status in setup_and_load_model():
            yield {log_box: status}

        target_dt = datetime.datetime.fromisoformat(dt_str)

        # --- Stage 2: Fetch and Process Data ---
        # The pipeline yields string status updates and, last, a tuple
        # carrying the actual results.
        for status in fetch_and_process_sdo_data(target_dt):
            if isinstance(status, tuple):
                input_tensor, last_input_map, target_map = status
                break
            yield {log_box: status}
        else:
            # The generator was exhausted without ever producing results.
            raise gr.Error("Data processing pipeline finished unexpectedly.")

        # --- Stage 3: Run Inference ---
        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        # --- Stage 4: Generate Visualization ---
        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(
            last_input_map, prediction_tensor, target_map, "aia171"
        )

        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
            results_group: gr.update(visible=True),
            # Keep the raw data in state so the channel selector can re-render.
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            # Display final images.
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }

    except Exception as e:
        error_str = traceback.format_exc()
        logger.error(f"An error occurred: {e}\n{error_str}")
        yield {log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}"}

    # Re-enable the inputs. This deliberately sits after the try/except
    # rather than in a ``finally``: yielding from ``finally`` while the
    # generator is being closed raises
    # "RuntimeError: generator ignored GeneratorExit". Every ordinary
    # exception is handled above, so this line is still reached on both the
    # success and the error path.
    yield {
        run_button: gr.update(interactive=True),
        datetime_input: gr.update(interactive=True),
    }
266
 
 
 
 
267
 
268
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session storage so the channel selector can re-render a finished
    # forecast without downloading data or running the model again.
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

    # The component arguments below are restored from the previous revision
    # of this file; they had been replaced by literal ``...`` placeholders,
    # which pass ``Ellipsis`` to the component constructors.
    gr.Markdown(
        """
        <div align='center'>
        # ☀️ Surya: Live Forecast Demo ☀️
        ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
        **Instructions:**
        1. Pick a date and time (at least 3 hours in the past).
        2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
        3. Once complete, select different channels to explore the multi-spectrum forecast.
        </div>
        """
    )

    with gr.Row():
        datetime_input = gr.Textbox(
            label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
            value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S"),
        )
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    # Dedicated box for streamed status messages and error tracebacks.
    log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5)

    with gr.Group(visible=False) as results_group:
        channel_selector = gr.Dropdown(choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel")
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    # forecast_controller is a generator that yields dicts keyed by
    # component, so every component it may update must be listed here.
    run_button.click(
        fn=forecast_controller,
        inputs=[datetime_input],
        outputs=[
            log_box, run_button, datetime_input, results_group,
            state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    )

    # Re-rendering a different channel is fast, so a plain function suffices.
    channel_selector.change(
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )