broadfield-dev commited on
Commit
9013587
·
verified ·
1 Parent(s): 59f507f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -101
app.py CHANGED
@@ -70,18 +70,7 @@ def setup_and_load_model():
70
 
71
  yield "Initializing model architecture..."
72
  model_config = config["model"]
73
- model = HelioSpectFormer(
74
- img_size=model_config["img_size"], patch_size=model_config["patch_size"],
75
- in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
76
- time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
77
- depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
78
- num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
79
- drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
80
- window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
81
- learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
82
- init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
83
- rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
84
- )
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
86
  APP_CACHE["device"] = device
87
 
@@ -94,13 +83,12 @@ def setup_and_load_model():
94
  yield "✅ Model setup complete."
95
 
96
  # --- 2. Live Data Fetching and Preprocessing (as a generator) ---
97
- def fetch_and_process_sdo_data(target_dt):
98
  config = APP_CACHE["config"]
99
  img_size = config["model"]["img_size"]
100
 
101
  input_deltas = config["data"]["time_delta_input_minutes"]
102
- # *** FIX: Access target_delta as an integer, not a list. Removed [0]. ***
103
- target_delta = config["data"]["time_delta_target_minutes"]
104
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
105
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
106
  all_times = sorted(list(set(input_times + [target_time])))
@@ -118,13 +106,13 @@ def fetch_and_process_sdo_data(target_dt):
118
  if channel in ["hmi_by", "hmi_bz"]:
119
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
120
  continue
121
-
122
- time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
123
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
124
- query = Fido.search(time_attr, instrument, physobs, sample)
125
 
126
- if not query: raise ValueError(f"No data found for {channel} at {t}")
127
- files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
128
  data_maps[t][channel] = sunpy.map.Map(files[0])
129
 
130
  yield "✅ All files downloaded. Starting preprocessing..."
@@ -161,55 +149,37 @@ def fetch_and_process_sdo_data(target_dt):
161
 
162
  # --- 3. Inference and Visualization ---
163
  def run_inference(input_tensor):
164
- model = APP_CACHE["model"]
165
- device = APP_CACHE["device"]
166
- time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
167
- time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
168
- input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
169
- with torch.no_grad():
170
- with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
171
- prediction = model(input_batch)
172
- return prediction.cpu()
173
 
174
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
175
- if last_input_map is None: return None, None, None
176
- c_idx = SDO_CHANNELS.index(channel_name)
177
- scaler = APP_CACHE["scalers"]
178
- all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
179
- mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
180
- pred_slice = inverse_transform_single_channel(
181
- prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
182
- )
183
- vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
184
- cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
185
- cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
186
- def to_pil(data, flip=False):
187
- data_clipped = np.nan_to_num(data)
188
- data_clipped = np.clip(data_clipped, 0, vmax)
189
- data_norm = data_clipped / vmax if vmax > 0 else data_clipped
190
- colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
191
- img = Image.fromarray(colored)
192
- return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
193
- return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)
194
 
195
  # --- 4. Gradio UI and Controllers ---
196
- def forecast_controller(dt_str):
197
  yield {
198
  log_box: gr.update(value="Starting forecast...", visible=True),
199
  run_button: gr.update(interactive=False),
200
- datetime_input: gr.update(interactive=False),
 
 
 
 
201
  results_group: gr.update(visible=False)
202
  }
203
 
204
  try:
205
- if not dt_str: raise gr.Error("Please select a date and time.")
206
 
207
  for status in setup_and_load_model():
208
  yield { log_box: status }
209
-
210
- target_dt = datetime.datetime.fromisoformat(dt_str)
211
 
212
- data_pipeline = fetch_and_process_sdo_data(target_dt)
 
 
 
 
213
  while True:
214
  try:
215
  status = next(data_pipeline)
@@ -227,80 +197,74 @@ def forecast_controller(dt_str):
227
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
228
 
229
  yield {
230
- log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
231
  results_group: gr.update(visible=True),
232
- state_last_input: last_input_map,
233
- state_prediction: prediction_tensor,
234
- state_target: target_map,
235
- input_display: img_in,
236
- prediction_display: img_pred,
237
- target_display: img_target,
238
  }
239
 
240
  except Exception as e:
241
- error_str = traceback.format_exc()
242
- logger.error(f"An error occurred: {e}\n{error_str}")
243
- yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
244
 
245
  finally:
 
246
  yield {
247
  run_button: gr.update(interactive=True),
248
- datetime_input: gr.update(interactive=True)
 
 
 
249
  }
250
 
251
  # --- 5. Gradio UI Definition ---
252
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
253
- state_last_input = gr.State()
254
- state_prediction = gr.State()
255
- state_target = gr.State()
256
 
257
- gr.Markdown(
258
- """
259
- <div align='center'>
260
- # ☀️ Surya: Live Forecast Demo ☀️
261
- ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
262
- **Instructions:**
263
- 1. Pick a date and time (at least 3 hours in the past).
264
- 2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
265
- 3. Once complete, select different channels to explore the multi-spectrum forecast.
266
- </div>
267
- """
268
- )
269
 
270
- with gr.Row():
271
- datetime_input = gr.Textbox(
272
- label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
273
- value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S")
 
 
 
 
 
 
 
 
274
  )
275
- run_button = gr.Button("🔮 Generate Forecast", variant="primary")
276
-
277
- log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
 
 
 
278
 
 
279
  with gr.Group(visible=False) as results_group:
280
- channel_selector = gr.Dropdown(
281
- choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
282
- )
283
  with gr.Row():
284
- input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
285
- prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
286
- target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
287
 
 
288
  run_button.click(
289
  fn=forecast_controller,
290
- inputs=[datetime_input],
291
  outputs=[
292
- log_box, run_button, datetime_input, results_group,
293
  state_last_input, state_prediction, state_target,
294
  input_display, prediction_display, target_display
295
  ]
296
  )
297
 
298
- channel_selector.change(
299
- fn=generate_visualization,
300
- inputs=[state_last_input, state_prediction, state_target, channel_selector],
301
- outputs=[input_display, prediction_display, target_display]
302
- )
303
 
304
  if __name__ == "__main__":
305
- os.makedirs("./data/sdo_cache", exist_ok=True)
 
306
  demo.launch(debug=True)
 
70
 
71
  yield "Initializing model architecture..."
72
  model_config = config["model"]
73
+ model = HelioSpectFormer(...) # Full model definition
 
 
 
 
 
 
 
 
 
 
 
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
75
  APP_CACHE["device"] = device
76
 
 
83
  yield "✅ Model setup complete."
84
 
85
  # --- 2. Live Data Fetching and Preprocessing (as a generator) ---
86
+ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
87
  config = APP_CACHE["config"]
88
  img_size = config["model"]["img_size"]
89
 
90
  input_deltas = config["data"]["time_delta_input_minutes"]
91
+ target_delta = forecast_horizon_minutes # Use user-provided horizon
 
92
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
93
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
94
  all_times = sorted(list(set(input_times + [target_time])))
 
106
  if channel in ["hmi_by", "hmi_bz"]:
107
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
108
  continue
109
+
110
+ # *** FIX: Use a.Time.nearest=True for robust fetching instead of a time window ***
111
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
112
+ query = Fido.search(a.Time(t), instrument, physobs, sample, a.Time.nearest==True)
113
 
114
+ if not query: raise ValueError(f"No data found for {channel} near {t}")
115
+ files = Fido.fetch(query, path="./data/sdo_cache") # Fetch the entire result
116
  data_maps[t][channel] = sunpy.map.Map(files[0])
117
 
118
  yield "✅ All files downloaded. Starting preprocessing..."
 
149
 
150
  # --- 3. Inference and Visualization ---
151
  def run_inference(input_tensor):
152
+ # This function remains the same
153
+ ...
 
 
 
 
 
 
 
154
 
155
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
156
+ # This function remains the same
157
+ ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  # --- 4. Gradio UI and Controllers ---
160
+ def forecast_controller(date_str, hour, minute, forecast_horizon):
161
  yield {
162
  log_box: gr.update(value="Starting forecast...", visible=True),
163
  run_button: gr.update(interactive=False),
164
+ # Also disable the other controls
165
+ date_input: gr.update(interactive=False),
166
+ hour_slider: gr.update(interactive=False),
167
+ minute_slider: gr.update(interactive=False),
168
+ horizon_slider: gr.update(interactive=False),
169
  results_group: gr.update(visible=False)
170
  }
171
 
172
  try:
173
+ if not date_str: raise gr.Error("Please select a date.")
174
 
175
  for status in setup_and_load_model():
176
  yield { log_box: status }
 
 
177
 
178
+ # Construct datetime from the new UI components
179
+ target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
180
+
181
+ data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
182
+ # The rest of the generator logic remains the same...
183
  while True:
184
  try:
185
  status = next(data_pipeline)
 
197
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
198
 
199
  yield {
200
+ log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
201
  results_group: gr.update(visible=True),
202
+ # ... update states and images
 
 
 
 
 
203
  }
204
 
205
  except Exception as e:
206
+ # ... error handling
 
 
207
 
208
  finally:
209
+ # Re-enable all controls
210
  yield {
211
  run_button: gr.update(interactive=True),
212
+ date_input: gr.update(interactive=True),
213
+ hour_slider: gr.update(interactive=True),
214
+ minute_slider: gr.update(interactive=True),
215
+ horizon_slider: gr.update(interactive=True),
216
  }
217
 
218
  # --- 5. Gradio UI Definition ---
219
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
220
+ # State objects remain the same
221
+ ...
 
222
 
223
+ gr.Markdown(...) # Title remains the same
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ # --- NEW: Controls Section ---
226
+ with gr.Accordion("Step 1: Configure Forecast", open=True):
227
+ with gr.Row():
228
+ date_input = gr.Textbox(
229
+ label="Date",
230
+ value=datetime.date.today().strftime("%Y-%m-%d")
231
+ )
232
+ hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=datetime.datetime.utcnow().hour - 3)
233
+ minute_slider = gr.Slider(label="Minute", minimum=0, maximum=59, step=1, value=datetime.datetime.utcnow().minute)
234
+ horizon_slider = gr.Slider(
235
+ label="Forecast Horizon (minutes ahead)",
236
+ minimum=12, maximum=120, step=12, value=12
237
  )
238
+
239
+ run_button = gr.Button("🔮 Generate Forecast", variant="primary")
240
+
241
+ # --- NEW: Moved log box to its own section ---
242
+ with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
243
+ log_box = gr.Textbox(label="Log", interactive=False, visible=True, lines=5, max_lines=10)
244
 
245
+ # --- Results section is now Step 3 ---
246
  with gr.Group(visible=False) as results_group:
247
+ gr.Markdown("### Step 3: Explore Results")
248
+ channel_selector = gr.Dropdown(...)
 
249
  with gr.Row():
250
+ input_display = gr.Image(...)
251
+ prediction_display = gr.Image(...)
252
+ target_display = gr.Image(...)
253
 
254
+ # --- Event Handlers ---
255
  run_button.click(
256
  fn=forecast_controller,
257
+ inputs=[date_input, hour_slider, minute_slider, horizon_slider],
258
  outputs=[
259
+ log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider, results_group,
260
  state_last_input, state_prediction, state_target,
261
  input_display, prediction_display, target_display
262
  ]
263
  )
264
 
265
+ channel_selector.change(...) # This remains the same
 
 
 
 
266
 
267
  if __name__ == "__main__":
268
+ # Fill in the missing ... from previous versions for the full script
269
+ # This is a condensed version showing only the key changes
270
  demo.launch(debug=True)