broadfield-dev commited on
Commit
e1181ea
·
verified ·
1 Parent(s): d4d895f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -30
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Save this file as app.py in the root of the cloned Surya repository
2
 
3
  import gradio as gr
4
  import torch
@@ -155,12 +155,10 @@ def fetch_and_process_sdo_data(target_dt):
155
  target_map = data_maps[target_time]
156
  last_input_map = data_maps[input_times[-1]]
157
 
158
- # The final yield of a generator is its return value
159
  yield (input_tensor, last_input_map, target_map)
160
 
161
 
162
  # --- 3. Inference and Visualization ---
163
- # (These are fast and don't need to be generators)
164
  def run_inference(input_tensor):
165
  model = APP_CACHE["model"]
166
  device = APP_CACHE["device"]
@@ -195,9 +193,6 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
195
 
196
  # --- 4. Gradio UI and Controllers ---
197
  def forecast_controller(dt_str):
198
- # This is now a generator function that yields updates to the UI
199
-
200
- # Initial UI state: disable inputs, clear old results
201
  yield {
202
  log_box: gr.update(value="Starting forecast...", visible=True),
203
  run_button: gr.update(interactive=False),
@@ -208,45 +203,34 @@ def forecast_controller(dt_str):
208
  try:
209
  if not dt_str: raise gr.Error("Please select a date and time.")
210
 
211
- # --- Stage 1: Setup Model ---
212
- # The setup function is also a generator, so we loop through its yields
213
  for status in setup_and_load_model():
214
  yield { log_box: status }
215
 
216
  target_dt = datetime.datetime.fromisoformat(dt_str)
217
 
218
- # --- Stage 2: Fetch and Process Data ---
219
- # We loop through the yields from the data pipeline
220
  data_pipeline = fetch_and_process_sdo_data(target_dt)
221
  while True:
222
  try:
223
- # Get the next status update
224
  status = next(data_pipeline)
225
- # If it's a tuple, it's the final return value
226
  if isinstance(status, tuple):
227
  input_tensor, last_input_map, target_map = status
228
  break
229
- # Otherwise, it's a string update
230
  yield { log_box: status }
231
  except StopIteration:
232
  raise gr.Error("Data processing pipeline finished unexpectedly.")
233
 
234
- # --- Stage 3: Run Inference ---
235
  yield { log_box: "Running AI model inference..." }
236
  prediction_tensor = run_inference(input_tensor)
237
 
238
- # --- Stage 4: Generate Visualization ---
239
  yield { log_box: "Generating final visualizations..." }
240
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
241
 
242
  yield {
243
  log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
244
  results_group: gr.update(visible=True),
245
- # Pass final data to state objects
246
  state_last_input: last_input_map,
247
  state_prediction: prediction_tensor,
248
  state_target: target_map,
249
- # Display final images
250
  input_display: img_in,
251
  prediction_display: img_pred,
252
  target_display: img_target,
@@ -258,36 +242,49 @@ def forecast_controller(dt_str):
258
  yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
259
 
260
  finally:
261
- # Final UI state: re-enable inputs
262
  yield {
263
  run_button: gr.update(interactive=True),
264
  datetime_input: gr.update(interactive=True)
265
  }
266
 
267
-
268
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
269
  state_last_input = gr.State()
270
  state_prediction = gr.State()
271
  state_target = gr.State()
272
 
273
- gr.Markdown(...) # UI definition is the same
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  with gr.Row():
276
- datetime_input = gr.Textbox(...)
 
 
 
277
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
278
 
279
- # NEW: A dedicated box for logs and feedback
280
- log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5)
281
 
282
  with gr.Group(visible=False) as results_group:
283
- channel_selector = gr.Dropdown(...)
 
 
284
  with gr.Row():
285
- input_display = gr.Image(...)
286
- prediction_display = gr.Image(...)
287
- target_display = gr.Image(...)
288
 
289
- # The .click() event is now pointed to our generator function
290
- # It updates multiple components based on what the generator yields
291
  run_button.click(
292
  fn=forecast_controller,
293
  inputs=[datetime_input],
@@ -299,7 +296,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
299
  )
300
 
301
  channel_selector.change(
302
- fn=generate_visualization, # This is a fast function, no generator needed
303
  inputs=[state_last_input, state_prediction, state_target, channel_selector],
304
  outputs=[input_display, prediction_display, target_display]
305
  )
 
1
+ # Save this file as in the root of the cloned Surya repository
2
 
3
  import gradio as gr
4
  import torch
 
155
  target_map = data_maps[target_time]
156
  last_input_map = data_maps[input_times[-1]]
157
 
 
158
  yield (input_tensor, last_input_map, target_map)
159
 
160
 
161
  # --- 3. Inference and Visualization ---
 
162
  def run_inference(input_tensor):
163
  model = APP_CACHE["model"]
164
  device = APP_CACHE["device"]
 
193
 
194
  # --- 4. Gradio UI and Controllers ---
195
  def forecast_controller(dt_str):
 
 
 
196
  yield {
197
  log_box: gr.update(value="Starting forecast...", visible=True),
198
  run_button: gr.update(interactive=False),
 
203
  try:
204
  if not dt_str: raise gr.Error("Please select a date and time.")
205
 
 
 
206
  for status in setup_and_load_model():
207
  yield { log_box: status }
208
 
209
  target_dt = datetime.datetime.fromisoformat(dt_str)
210
 
 
 
211
  data_pipeline = fetch_and_process_sdo_data(target_dt)
212
  while True:
213
  try:
 
214
  status = next(data_pipeline)
 
215
  if isinstance(status, tuple):
216
  input_tensor, last_input_map, target_map = status
217
  break
 
218
  yield { log_box: status }
219
  except StopIteration:
220
  raise gr.Error("Data processing pipeline finished unexpectedly.")
221
 
 
222
  yield { log_box: "Running AI model inference..." }
223
  prediction_tensor = run_inference(input_tensor)
224
 
 
225
  yield { log_box: "Generating final visualizations..." }
226
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
227
 
228
  yield {
229
  log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
230
  results_group: gr.update(visible=True),
 
231
  state_last_input: last_input_map,
232
  state_prediction: prediction_tensor,
233
  state_target: target_map,
 
234
  input_display: img_in,
235
  prediction_display: img_pred,
236
  target_display: img_target,
 
242
  yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
243
 
244
  finally:
 
245
  yield {
246
  run_button: gr.update(interactive=True),
247
  datetime_input: gr.update(interactive=True)
248
  }
249
 
250
+ # --- 5. Gradio UI Definition ---
251
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
252
  state_last_input = gr.State()
253
  state_prediction = gr.State()
254
  state_target = gr.State()
255
 
256
+ # *** FIX: Replaced all '...' with complete UI component definitions ***
257
+ gr.Markdown(
258
+ """
259
+ <div align='center'>
260
+ # ☀️ Surya: Live Forecast Demo ☀️
261
+ ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
262
+ **Instructions:**
263
+ 1. Pick a date and time (at least 3 hours in the past).
264
+ 2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
265
+ 3. Once complete, select different channels to explore the multi-spectrum forecast.
266
+ </div>
267
+ """
268
+ )
269
 
270
  with gr.Row():
271
+ datetime_input = gr.Textbox(
272
+ label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
273
+ value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S")
274
+ )
275
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
276
 
277
+ log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
 
278
 
279
  with gr.Group(visible=False) as results_group:
280
+ channel_selector = gr.Dropdown(
281
+ choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
282
+ )
283
  with gr.Row():
284
+ input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
285
+ prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
286
+ target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
287
 
 
 
288
  run_button.click(
289
  fn=forecast_controller,
290
  inputs=[datetime_input],
 
296
  )
297
 
298
  channel_selector.change(
299
+ fn=generate_visualization,
300
  inputs=[state_last_input, state_prediction, state_target, channel_selector],
301
  outputs=[input_display, prediction_display, target_display]
302
  )