Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# Save this file as
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
@@ -155,12 +155,10 @@ def fetch_and_process_sdo_data(target_dt):
|
|
155 |
target_map = data_maps[target_time]
|
156 |
last_input_map = data_maps[input_times[-1]]
|
157 |
|
158 |
-
# The final yield of a generator is its return value
|
159 |
yield (input_tensor, last_input_map, target_map)
|
160 |
|
161 |
|
162 |
# --- 3. Inference and Visualization ---
|
163 |
-
# (These are fast and don't need to be generators)
|
164 |
def run_inference(input_tensor):
|
165 |
model = APP_CACHE["model"]
|
166 |
device = APP_CACHE["device"]
|
@@ -195,9 +193,6 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
|
|
195 |
|
196 |
# --- 4. Gradio UI and Controllers ---
|
197 |
def forecast_controller(dt_str):
|
198 |
-
# This is now a generator function that yields updates to the UI
|
199 |
-
|
200 |
-
# Initial UI state: disable inputs, clear old results
|
201 |
yield {
|
202 |
log_box: gr.update(value="Starting forecast...", visible=True),
|
203 |
run_button: gr.update(interactive=False),
|
@@ -208,45 +203,34 @@ def forecast_controller(dt_str):
|
|
208 |
try:
|
209 |
if not dt_str: raise gr.Error("Please select a date and time.")
|
210 |
|
211 |
-
# --- Stage 1: Setup Model ---
|
212 |
-
# The setup function is also a generator, so we loop through its yields
|
213 |
for status in setup_and_load_model():
|
214 |
yield { log_box: status }
|
215 |
|
216 |
target_dt = datetime.datetime.fromisoformat(dt_str)
|
217 |
|
218 |
-
# --- Stage 2: Fetch and Process Data ---
|
219 |
-
# We loop through the yields from the data pipeline
|
220 |
data_pipeline = fetch_and_process_sdo_data(target_dt)
|
221 |
while True:
|
222 |
try:
|
223 |
-
# Get the next status update
|
224 |
status = next(data_pipeline)
|
225 |
-
# If it's a tuple, it's the final return value
|
226 |
if isinstance(status, tuple):
|
227 |
input_tensor, last_input_map, target_map = status
|
228 |
break
|
229 |
-
# Otherwise, it's a string update
|
230 |
yield { log_box: status }
|
231 |
except StopIteration:
|
232 |
raise gr.Error("Data processing pipeline finished unexpectedly.")
|
233 |
|
234 |
-
# --- Stage 3: Run Inference ---
|
235 |
yield { log_box: "Running AI model inference..." }
|
236 |
prediction_tensor = run_inference(input_tensor)
|
237 |
|
238 |
-
# --- Stage 4: Generate Visualization ---
|
239 |
yield { log_box: "Generating final visualizations..." }
|
240 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
241 |
|
242 |
yield {
|
243 |
log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
|
244 |
results_group: gr.update(visible=True),
|
245 |
-
# Pass final data to state objects
|
246 |
state_last_input: last_input_map,
|
247 |
state_prediction: prediction_tensor,
|
248 |
state_target: target_map,
|
249 |
-
# Display final images
|
250 |
input_display: img_in,
|
251 |
prediction_display: img_pred,
|
252 |
target_display: img_target,
|
@@ -258,36 +242,49 @@ def forecast_controller(dt_str):
|
|
258 |
yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
|
259 |
|
260 |
finally:
|
261 |
-
# Final UI state: re-enable inputs
|
262 |
yield {
|
263 |
run_button: gr.update(interactive=True),
|
264 |
datetime_input: gr.update(interactive=True)
|
265 |
}
|
266 |
|
267 |
-
|
268 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
269 |
state_last_input = gr.State()
|
270 |
state_prediction = gr.State()
|
271 |
state_target = gr.State()
|
272 |
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
with gr.Row():
|
276 |
-
datetime_input = gr.Textbox(
|
|
|
|
|
|
|
277 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
278 |
|
279 |
-
|
280 |
-
log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5)
|
281 |
|
282 |
with gr.Group(visible=False) as results_group:
|
283 |
-
channel_selector = gr.Dropdown(
|
|
|
|
|
284 |
with gr.Row():
|
285 |
-
input_display = gr.Image(
|
286 |
-
prediction_display = gr.Image(
|
287 |
-
target_display = gr.Image(
|
288 |
|
289 |
-
# The .click() event is now pointed to our generator function
|
290 |
-
# It updates multiple components based on what the generator yields
|
291 |
run_button.click(
|
292 |
fn=forecast_controller,
|
293 |
inputs=[datetime_input],
|
@@ -299,7 +296,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
299 |
)
|
300 |
|
301 |
channel_selector.change(
|
302 |
-
fn=generate_visualization,
|
303 |
inputs=[state_last_input, state_prediction, state_target, channel_selector],
|
304 |
outputs=[input_display, prediction_display, target_display]
|
305 |
)
|
|
|
1 |
+
# Save this file as in the root of the cloned Surya repository
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
|
|
155 |
target_map = data_maps[target_time]
|
156 |
last_input_map = data_maps[input_times[-1]]
|
157 |
|
|
|
158 |
yield (input_tensor, last_input_map, target_map)
|
159 |
|
160 |
|
161 |
# --- 3. Inference and Visualization ---
|
|
|
162 |
def run_inference(input_tensor):
|
163 |
model = APP_CACHE["model"]
|
164 |
device = APP_CACHE["device"]
|
|
|
193 |
|
194 |
# --- 4. Gradio UI and Controllers ---
|
195 |
def forecast_controller(dt_str):
|
|
|
|
|
|
|
196 |
yield {
|
197 |
log_box: gr.update(value="Starting forecast...", visible=True),
|
198 |
run_button: gr.update(interactive=False),
|
|
|
203 |
try:
|
204 |
if not dt_str: raise gr.Error("Please select a date and time.")
|
205 |
|
|
|
|
|
206 |
for status in setup_and_load_model():
|
207 |
yield { log_box: status }
|
208 |
|
209 |
target_dt = datetime.datetime.fromisoformat(dt_str)
|
210 |
|
|
|
|
|
211 |
data_pipeline = fetch_and_process_sdo_data(target_dt)
|
212 |
while True:
|
213 |
try:
|
|
|
214 |
status = next(data_pipeline)
|
|
|
215 |
if isinstance(status, tuple):
|
216 |
input_tensor, last_input_map, target_map = status
|
217 |
break
|
|
|
218 |
yield { log_box: status }
|
219 |
except StopIteration:
|
220 |
raise gr.Error("Data processing pipeline finished unexpectedly.")
|
221 |
|
|
|
222 |
yield { log_box: "Running AI model inference..." }
|
223 |
prediction_tensor = run_inference(input_tensor)
|
224 |
|
|
|
225 |
yield { log_box: "Generating final visualizations..." }
|
226 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
227 |
|
228 |
yield {
|
229 |
log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
|
230 |
results_group: gr.update(visible=True),
|
|
|
231 |
state_last_input: last_input_map,
|
232 |
state_prediction: prediction_tensor,
|
233 |
state_target: target_map,
|
|
|
234 |
input_display: img_in,
|
235 |
prediction_display: img_pred,
|
236 |
target_display: img_target,
|
|
|
242 |
yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
|
243 |
|
244 |
finally:
|
|
|
245 |
yield {
|
246 |
run_button: gr.update(interactive=True),
|
247 |
datetime_input: gr.update(interactive=True)
|
248 |
}
|
249 |
|
250 |
+
# --- 5. Gradio UI Definition ---
|
251 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
252 |
state_last_input = gr.State()
|
253 |
state_prediction = gr.State()
|
254 |
state_target = gr.State()
|
255 |
|
256 |
+
# *** FIX: Replaced all '...' with complete UI component definitions ***
|
257 |
+
gr.Markdown(
|
258 |
+
"""
|
259 |
+
<div align='center'>
|
260 |
+
# ☀️ Surya: Live Forecast Demo ☀️
|
261 |
+
### Generate a real forecast for any recent date using NASA's Heliophysics Model.
|
262 |
+
**Instructions:**
|
263 |
+
1. Pick a date and time (at least 3 hours in the past).
|
264 |
+
2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
|
265 |
+
3. Once complete, select different channels to explore the multi-spectrum forecast.
|
266 |
+
</div>
|
267 |
+
"""
|
268 |
+
)
|
269 |
|
270 |
with gr.Row():
|
271 |
+
datetime_input = gr.Textbox(
|
272 |
+
label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
|
273 |
+
value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S")
|
274 |
+
)
|
275 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
276 |
|
277 |
+
log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
|
|
|
278 |
|
279 |
with gr.Group(visible=False) as results_group:
|
280 |
+
channel_selector = gr.Dropdown(
|
281 |
+
choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
|
282 |
+
)
|
283 |
with gr.Row():
|
284 |
+
input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
|
285 |
+
prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
|
286 |
+
target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
|
287 |
|
|
|
|
|
288 |
run_button.click(
|
289 |
fn=forecast_controller,
|
290 |
inputs=[datetime_input],
|
|
|
296 |
)
|
297 |
|
298 |
channel_selector.change(
|
299 |
+
fn=generate_visualization,
|
300 |
inputs=[state_last_input, state_prediction, state_target, channel_selector],
|
301 |
outputs=[input_display, prediction_display, target_display]
|
302 |
)
|