broadfield-dev commited on
Commit
9f9c5a3
·
verified ·
1 Parent(s): 9013587

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -42
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # Save this file as in the root of the cloned Surya repository
2
-
3
  import gradio as gr
4
  import torch
5
  from huggingface_hub import snapshot_download
@@ -20,18 +18,15 @@ import matplotlib.pyplot as plt
20
  import sunpy.visualization.colormaps as sunpy_cm
21
  import traceback
22
 
23
- # --- Use the official Surya modules ---
24
  from surya.models.helio_spectformer import HelioSpectFormer
25
  from surya.utils.data import build_scalers
26
  from surya.datasets.helio import inverse_transform_single_channel
27
 
28
- # --- Configuration ---
29
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
30
  warnings.filterwarnings("ignore", category=FutureWarning)
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
34
- # Global cache for model, config, etc.
35
  APP_CACHE = {}
36
 
37
  SDO_CHANNELS_MAP = {
@@ -45,13 +40,12 @@ SDO_CHANNELS_MAP = {
45
  "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
46
  "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
47
  "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
48
- "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
49
- "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
50
  "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
51
  }
52
  SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
53
 
54
- # --- 1. Model Loading and Setup ---
55
  def setup_and_load_model():
56
  if "model" in APP_CACHE:
57
  yield "Model already loaded. Skipping setup."
@@ -70,7 +64,18 @@ def setup_and_load_model():
70
 
71
  yield "Initializing model architecture..."
72
  model_config = config["model"]
73
- model = HelioSpectFormer(...) # Full model definition
 
 
 
 
 
 
 
 
 
 
 
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
75
  APP_CACHE["device"] = device
76
 
@@ -82,13 +87,12 @@ def setup_and_load_model():
82
  APP_CACHE["model"] = model
83
  yield "✅ Model setup complete."
84
 
85
- # --- 2. Live Data Fetching and Preprocessing (as a generator) ---
86
  def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
87
  config = APP_CACHE["config"]
88
  img_size = config["model"]["img_size"]
89
 
90
  input_deltas = config["data"]["time_delta_input_minutes"]
91
- target_delta = forecast_horizon_minutes # Use user-provided horizon
92
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
93
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
94
  all_times = sorted(list(set(input_times + [target_time])))
@@ -107,12 +111,11 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
107
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
108
  continue
109
 
110
- # *** FIX: Use a.Time.nearest=True for robust fetching instead of a time window ***
111
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
112
  query = Fido.search(a.Time(t), instrument, physobs, sample, a.Time.nearest==True)
113
 
114
  if not query: raise ValueError(f"No data found for {channel} near {t}")
115
- files = Fido.fetch(query, path="./data/sdo_cache") # Fetch the entire result
116
  data_maps[t][channel] = sunpy.map.Map(files[0])
117
 
118
  yield "✅ All files downloaded. Starting preprocessing..."
@@ -146,22 +149,42 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
146
 
147
  yield (input_tensor, last_input_map, target_map)
148
 
149
-
150
- # --- 3. Inference and Visualization ---
151
  def run_inference(input_tensor):
152
- # This function remains the same
153
- ...
 
 
 
 
 
 
 
154
 
155
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
156
- # This function remains the same
157
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # --- 4. Gradio UI and Controllers ---
160
  def forecast_controller(date_str, hour, minute, forecast_horizon):
161
  yield {
162
  log_box: gr.update(value="Starting forecast...", visible=True),
163
  run_button: gr.update(interactive=False),
164
- # Also disable the other controls
165
  date_input: gr.update(interactive=False),
166
  hour_slider: gr.update(interactive=False),
167
  minute_slider: gr.update(interactive=False),
@@ -175,11 +198,9 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
175
  for status in setup_and_load_model():
176
  yield { log_box: status }
177
 
178
- # Construct datetime from the new UI components
179
  target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
180
 
181
  data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
182
- # The rest of the generator logic remains the same...
183
  while True:
184
  try:
185
  status = next(data_pipeline)
@@ -199,14 +220,20 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
199
  yield {
200
  log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
201
  results_group: gr.update(visible=True),
202
- # ... update states and images
 
 
 
 
 
203
  }
204
 
205
  except Exception as e:
206
- # ... error handling
 
 
207
 
208
  finally:
209
- # Re-enable all controls
210
  yield {
211
  run_button: gr.update(interactive=True),
212
  date_input: gr.update(interactive=True),
@@ -215,14 +242,24 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
215
  horizon_slider: gr.update(interactive=True),
216
  }
217
 
218
- # --- 5. Gradio UI Definition ---
219
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
220
- # State objects remain the same
221
- ...
 
222
 
223
- gr.Markdown(...) # Title remains the same
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- # --- NEW: Controls Section ---
226
  with gr.Accordion("Step 1: Configure Forecast", open=True):
227
  with gr.Row():
228
  date_input = gr.Textbox(
@@ -238,20 +275,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
238
 
239
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
240
 
241
- # --- NEW: Moved log box to its own section ---
242
  with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
243
- log_box = gr.Textbox(label="Log", interactive=False, visible=True, lines=5, max_lines=10)
244
 
245
- # --- Results section is now Step 3 ---
246
  with gr.Group(visible=False) as results_group:
247
  gr.Markdown("### Step 3: Explore Results")
248
- channel_selector = gr.Dropdown(...)
 
 
249
  with gr.Row():
250
- input_display = gr.Image(...)
251
- prediction_display = gr.Image(...)
252
- target_display = gr.Image(...)
253
 
254
- # --- Event Handlers ---
255
  run_button.click(
256
  fn=forecast_controller,
257
  inputs=[date_input, hour_slider, minute_slider, horizon_slider],
@@ -262,9 +298,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
262
  ]
263
  )
264
 
265
- channel_selector.change(...) # This remains the same
 
 
 
 
266
 
267
  if __name__ == "__main__":
268
- # Fill in the missing ... from previous versions for the full script
269
- # This is a condensed version showing only the key changes
270
  demo.launch(debug=True)
 
 
 
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import snapshot_download
 
18
  import sunpy.visualization.colormaps as sunpy_cm
19
  import traceback
20
 
 
21
  from surya.models.helio_spectformer import HelioSpectFormer
22
  from surya.utils.data import build_scalers
23
  from surya.datasets.helio import inverse_transform_single_channel
24
 
 
25
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
26
  warnings.filterwarnings("ignore", category=FutureWarning)
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
 
30
  APP_CACHE = {}
31
 
32
  SDO_CHANNELS_MAP = {
 
40
  "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
41
  "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
42
  "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
43
+ "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
44
+ "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
45
  "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
46
  }
47
  SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
48
 
 
49
  def setup_and_load_model():
50
  if "model" in APP_CACHE:
51
  yield "Model already loaded. Skipping setup."
 
64
 
65
  yield "Initializing model architecture..."
66
  model_config = config["model"]
67
+ model = HelioSpectFormer(
68
+ img_size=model_config["img_size"], patch_size=model_config["patch_size"],
69
+ in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
70
+ time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
71
+ depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
72
+ num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
73
+ drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
74
+ window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
75
+ learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
76
+ init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
77
+ rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
78
+ )
79
  device = "cuda" if torch.cuda.is_available() else "cpu"
80
  APP_CACHE["device"] = device
81
 
 
87
  APP_CACHE["model"] = model
88
  yield "✅ Model setup complete."
89
 
 
90
  def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
91
  config = APP_CACHE["config"]
92
  img_size = config["model"]["img_size"]
93
 
94
  input_deltas = config["data"]["time_delta_input_minutes"]
95
+ target_delta = forecast_horizon_minutes
96
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
97
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
98
  all_times = sorted(list(set(input_times + [target_time])))
 
111
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
112
  continue
113
 
 
114
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
115
  query = Fido.search(a.Time(t), instrument, physobs, sample, a.Time.nearest==True)
116
 
117
  if not query: raise ValueError(f"No data found for {channel} near {t}")
118
+ files = Fido.fetch(query, path="./data/sdo_cache")
119
  data_maps[t][channel] = sunpy.map.Map(files[0])
120
 
121
  yield "✅ All files downloaded. Starting preprocessing..."
 
149
 
150
  yield (input_tensor, last_input_map, target_map)
151
 
 
 
152
  def run_inference(input_tensor):
153
+ model = APP_CACHE["model"]
154
+ device = APP_CACHE["device"]
155
+ time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
156
+ time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
157
+ input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
158
+ with torch.no_grad():
159
+ with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
160
+ prediction = model(input_batch)
161
+ return prediction.cpu()
162
 
163
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
164
+ if last_input_map is None: return None, None, None
165
+ c_idx = SDO_CHANNELS.index(channel_name)
166
+ scaler = APP_CACHE["scalers"]
167
+ all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
168
+ mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
169
+ pred_slice = inverse_transform_single_channel(
170
+ prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
171
+ )
172
+ vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
173
+ cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
174
+ cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
175
+ def to_pil(data, flip=False):
176
+ data_clipped = np.nan_to_num(data)
177
+ data_clipped = np.clip(data_clipped, 0, vmax)
178
+ data_norm = data_clipped / vmax if vmax > 0 else data_clipped
179
+ colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
180
+ img = Image.fromarray(colored)
181
+ return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
182
+ return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)
183
 
 
184
  def forecast_controller(date_str, hour, minute, forecast_horizon):
185
  yield {
186
  log_box: gr.update(value="Starting forecast...", visible=True),
187
  run_button: gr.update(interactive=False),
 
188
  date_input: gr.update(interactive=False),
189
  hour_slider: gr.update(interactive=False),
190
  minute_slider: gr.update(interactive=False),
 
198
  for status in setup_and_load_model():
199
  yield { log_box: status }
200
 
 
201
  target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
202
 
203
  data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
 
204
  while True:
205
  try:
206
  status = next(data_pipeline)
 
220
  yield {
221
  log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
222
  results_group: gr.update(visible=True),
223
+ state_last_input: last_input_map,
224
+ state_prediction: prediction_tensor,
225
+ state_target: target_map,
226
+ input_display: img_in,
227
+ prediction_display: img_pred,
228
+ target_display: img_target,
229
  }
230
 
231
  except Exception as e:
232
+ error_str = traceback.format_exc()
233
+ logger.error(f"An error occurred: {e}\n{error_str}")
234
+ yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
235
 
236
  finally:
 
237
  yield {
238
  run_button: gr.update(interactive=True),
239
  date_input: gr.update(interactive=True),
 
242
  horizon_slider: gr.update(interactive=True),
243
  }
244
 
 
245
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
246
+ state_last_input = gr.State()
247
+ state_prediction = gr.State()
248
+ state_target = gr.State()
249
 
250
+ gr.Markdown(
251
+ """
252
+ <div align='center'>
253
+ # ☀️ Surya: Live Forecast Demo ☀️
254
+ ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
255
+ **Instructions:**
256
+ 1. Pick a date and time (at least 3 hours in the past).
257
+ 2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
258
+ 3. Once complete, select different channels to explore the multi-spectrum forecast.
259
+ </div>
260
+ """
261
+ )
262
 
 
263
  with gr.Accordion("Step 1: Configure Forecast", open=True):
264
  with gr.Row():
265
  date_input = gr.Textbox(
 
275
 
276
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
277
 
 
278
  with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
279
+ log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
280
 
 
281
  with gr.Group(visible=False) as results_group:
282
  gr.Markdown("### Step 3: Explore Results")
283
+ channel_selector = gr.Dropdown(
284
+ choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
285
+ )
286
  with gr.Row():
287
+ input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
288
+ prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
289
+ target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
290
 
 
291
  run_button.click(
292
  fn=forecast_controller,
293
  inputs=[date_input, hour_slider, minute_slider, horizon_slider],
 
298
  ]
299
  )
300
 
301
+ channel_selector.change(
302
+ fn=generate_visualization,
303
+ inputs=[state_last_input, state_prediction, state_target, channel_selector],
304
+ outputs=[input_display, prediction_display, target_display]
305
+ )
306
 
307
  if __name__ == "__main__":
308
+ os.makedirs("./data/sdo_cache", exist_ok=True)
 
309
  demo.launch(debug=True)