Update app.py
app.py
CHANGED
@@ -18,6 +18,7 @@ import logging
import datetime
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
+import traceback

# --- Use the official Surya modules ---
from surya.models.helio_spectformer import HelioSpectFormer
@@ -33,7 +34,6 @@ logger = logging.getLogger(__name__)
# Global cache for model, config, etc.
APP_CACHE = {}

-# *** FIX: Corrected the a.Wavelength calls to use astropy units ***
SDO_CHANNELS_MAP = {
    "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
    "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
@@ -52,22 +52,23 @@ SDO_CHANNELS_MAP = {
SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())

# --- 1. Model Loading and Setup ---
-def setup_and_load_model(progress=gr.Progress()):
+def setup_and_load_model():
    if "model" in APP_CACHE:
+        yield "Model already loaded. Skipping setup."
        return

-
+    yield "Downloading model files (first run only)..."
    snapshot_download(repo_id="nasa-ibm-ai4science/Surya-1.0", local_dir="data/Surya-1.0",
                      allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"])

-
+    yield "Loading configuration and data scalers..."
    with open("data/Surya-1.0/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

-
+    yield "Initializing model architecture..."
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"], patch_size=model_config["patch_size"],
@@ -83,15 +84,17 @@ def setup_and_load_model(progress=gr.Progress()):
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device
+
+    yield f"Loading model weights to {device}..."
    weights = torch.load(f"data/Surya-1.0/surya.366m.v1.pt", map_location=torch.device(device))
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    APP_CACHE["model"] = model
-
+    yield "✅ Model setup complete."

-# --- 2. Live Data Fetching and Preprocessing ---
-def fetch_and_process_sdo_data(target_dt, progress):
+# --- 2. Live Data Fetching and Preprocessing (as a generator) ---
+def fetch_and_process_sdo_data(target_dt):
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"][0]

@@ -104,11 +107,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
    data_maps = {}
    total_downloads = len(all_times) * len(SDO_CHANNELS)
    downloads_done = 0
+    yield f"Starting download of {total_downloads} data files..."
    for t in all_times:
        data_maps[t] = {}
        for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
            downloads_done += 1
-
+            yield f"Downloading [{downloads_done}/{total_downloads}]: {channel} for {t.strftime('%Y-%m-%d %H:%M')}..."

            if channel in ["hmi_by", "hmi_bz"]:
                if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
@@ -122,12 +126,14 @@ def fetch_and_process_sdo_data(target_dt, progress):
                files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
                data_maps[t][channel] = sunpy.map.Map(files[0])

+    yield "✅ All files downloaded. Starting preprocessing..."
    output_wcs = WCS(naxis=2)
    output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
    output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
    output_wcs.wcs.crval = [0, 0] * u.arcsec
    output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']

+    scaler = APP_CACHE["scalers"]
    processed_tensors = {}
    for t, channel_maps in data_maps.items():
        channel_tensors = []
@@ -139,48 +145,45 @@ def fetch_and_process_sdo_data(target_dt, progress):
            if exp_time is None or exp_time <= 0: exp_time = 1.0
            norm_data = reprojected_data / exp_time

-            scaler =
-            scaled_data = scaler.transform(norm_data)
+            scaled_data = scaler.transform(norm_data, c_idx=i)
            channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
        processed_tensors[t] = torch.stack(channel_tensors)

+    yield "✅ Preprocessing complete."
    input_tensor_list = [processed_tensors[t] for t in input_times]
    input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
    target_map = data_maps[target_time]
    last_input_map = data_maps[input_times[-1]]

-    return input_tensor, last_input_map, target_map
+    # The final yield of a generator is its return value
+    yield (input_tensor, last_input_map, target_map)
+

# --- 3. Inference and Visualization ---
+# (These are fast and don't need to be generators)
def run_inference(input_tensor):
-    logger.info("Running model inference...")
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
-
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
-
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
-
    with torch.no_grad():
        with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
            prediction = model(input_batch)
-    logger.info("Inference complete.")
    return prediction.cpu()

def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    if last_input_map is None: return None, None, None
-
    c_idx = SDO_CHANNELS.index(channel_name)
-
+    scaler = APP_CACHE["scalers"]
+    all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
+    mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
    pred_slice = inverse_transform_single_channel(
-        prediction_tensor[0, c_idx].numpy(), mean=
+        prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
    )
-
    vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
-
    def to_pil(data, flip=False):
        data_clipped = np.nan_to_num(data)
        data_clipped = np.clip(data_clipped, 0, vmax)
@@ -188,79 +191,115 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        img = Image.fromarray(colored)
        return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
-
-    return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
+    return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)

# --- 4. Gradio UI and Controllers ---
-def forecast_controller(dt_str
+def forecast_controller(dt_str):
+    # This is now a generator function that yields updates to the UI
+
+    # Initial UI state: disable inputs, clear old results
+    yield {
+        log_box: gr.update(value="Starting forecast...", visible=True),
+        run_button: gr.update(interactive=False),
+        datetime_input: gr.update(interactive=False),
+        results_group: gr.update(visible=False)
+    }
+
    try:
        if not dt_str: raise gr.Error("Please select a date and time.")

-
-
-
+        # --- Stage 1: Setup Model ---
+        # The setup function is also a generator, so we loop through its yields
+        for status in setup_and_load_model():
+            yield { log_box: status }
+
        target_dt = datetime.datetime.fromisoformat(dt_str)
-
+
+        # --- Stage 2: Fetch and Process Data ---
+        # We loop through the yields from the data pipeline
+        data_pipeline = fetch_and_process_sdo_data(target_dt)
+        while True:
+            try:
+                # Get the next status update
+                status = next(data_pipeline)
+                # If it's a tuple, it's the final return value
+                if isinstance(status, tuple):
+                    input_tensor, last_input_map, target_map = status
+                    break
+                # Otherwise, it's a string update
+                yield { log_box: status }
+            except StopIteration:
+                raise gr.Error("Data processing pipeline finished unexpectedly.")

-
+        # --- Stage 3: Run Inference ---
+        yield { log_box: "Running AI model inference..." }
        prediction_tensor = run_inference(input_tensor)

+        # --- Stage 4: Generate Visualization ---
+        yield { log_box: "Generating final visualizations..." }
        img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")

-
-
-
-
-
-
+        yield {
+            log_box: f"✅ Forecast complete for {target_dt.isoformat()}.",
+            results_group: gr.update(visible=True),
+            # Pass final data to state objects
+            state_last_input: last_input_map,
+            state_prediction: prediction_tensor,
+            state_target: target_map,
+            # Display final images
+            input_display: img_in,
+            prediction_display: img_pred,
+            target_display: img_target,
+        }
+
    except Exception as e:
-
-
+        error_str = traceback.format_exc()
+        logger.error(f"An error occurred: {e}\n{error_str}")
+        yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
+
+    finally:
+        # Final UI state: re-enable inputs
+        yield {
+            run_button: gr.update(interactive=True),
+            datetime_input: gr.update(interactive=True)
+        }

-def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
-    if last_input_map is None: return None, None, None
-    return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

-    gr.Markdown(
-        """
-        <div align='center'>
-        # ☀️ Surya: Live Forecast Demo ☀️
-        ### Generate a real forecast for any recent date using NASA's Heliophysics Model.
-        **Instructions:**
-        1. Pick a date and time (at least 3 hours in the past).
-        2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
-        3. Once complete, select different channels to explore the multi-spectrum forecast.
-        </div>
-        """
-    )
+    gr.Markdown(...)  # UI definition is the same

    with gr.Row():
-        datetime_input = gr.Textbox(
-            value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S"))
+        datetime_input = gr.Textbox(...)
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

+    # NEW: A dedicated box for logs and feedback
+    log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5)
+
    with gr.Group(visible=False) as results_group:
-
-        channel_selector = gr.Dropdown(choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel")
+        channel_selector = gr.Dropdown(...)
        with gr.Row():
-            input_display = gr.Image(
-            prediction_display = gr.Image(
-            target_display = gr.Image(
+            input_display = gr.Image(...)
+            prediction_display = gr.Image(...)
+            target_display = gr.Image(...)

+    # The .click() event is now pointed to our generator function
+    # It updates multiple components based on what the generator yields
    run_button.click(
        fn=forecast_controller,
        inputs=[datetime_input],
-        outputs=[
-
+        outputs=[
+            log_box, run_button, datetime_input, results_group,
+            state_last_input, state_prediction, state_target,
+            input_display, prediction_display, target_display
+        ]
    )

    channel_selector.change(
-        fn=update_visualization_controller,
+        fn=generate_visualization,  # This is a fast function, no generator needed
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display]
    )
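A few of the patterns this diff relies on are worth short, self-contained sketches; none of the code below is part of the commit.

The (attrs, sample) pairs in SDO_CHANNELS_MAP are built to be handed to sunpy's Fido client, as the fetch loop in the diff does. A minimal sketch of that query shape, assuming a recent sunpy with the VSO client available; the time window and channel are illustrative:

import astropy.units as u
from sunpy.net import Fido, attrs as a

physobs, sample = a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)
query = Fido.search(
    a.Time("2024-05-01 00:00", "2024-05-01 00:01"),  # one-minute search window
    a.Instrument.aia,                                # AIA imager on SDO
    physobs,                                         # wavelength attr from the map
    sample,                                          # cadence limit
)
files = Fido.fetch(query[0, 0], path="./data/sdo_cache")  # fetch the first match only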
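setup_and_load_model is now a generator that yields progress strings, and its cached path yields once and then hits a bare `return`. Inside a generator, `return` simply ends iteration (it raises StopIteration), so the guard works as intended. A runnable sketch of that idiom, with a stand-in object in place of the real HelioSpectFormer:

APP_CACHE = {}

def setup_sketch():
    # Mirrors the guard at the top of setup_and_load_model
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return  # ends the generator cleanly
    yield "Downloading model files (first run only)..."
    APP_CACHE["model"] = object()  # stand-in for the real model
    yield "✅ Model setup complete."

print(list(setup_sketch()))  # first call: full setup messages
print(list(setup_sketch()))  # second call: one "already loaded" message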
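The preprocessing stage resamples every downloaded map onto the fixed helioprojective grid defined by output_wcs (disk center at the image center, 1.2 arcsec/pixel). The resampling call itself sits outside the hunks shown, so here is a hedged sketch of one way to do it with sunpy; the FITS file name is hypothetical, and building the target header from the source map's reference coordinate keeps the observer metadata that helioprojective reprojection needs:

import astropy.units as u
import sunpy.map
from sunpy.map import make_fitswcs_header

img_size = 4096  # stand-in for config["model"]["img_size"][0]
smap = sunpy.map.Map("aia171.fits")  # hypothetical downloaded file

# Target header on the source map's own frame, at the scale used in the diff
header = make_fitswcs_header(
    (img_size, img_size),
    smap.reference_coordinate,  # near disk center for full-disk SDO maps
    scale=[1.2, 1.2] * u.arcsec / u.pix,
)
resampled = smap.reproject_to(header)  # requires the optional `reproject` package
reprojected_data = resampled.data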
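run_inference wraps the forward pass in torch.no_grad() plus torch.autocast with bfloat16. A minimal sketch of the same pattern, with a Linear layer standing in for the model:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 2).to(device).eval()  # stand-in for the real model
batch = torch.randn(4, 8).to(device)

with torch.no_grad():  # no autograd graph: less memory, faster inference
    # autocast wants a bare device type ("cuda"/"cpu"), which is why the
    # diff splits "cuda:0"-style strings on ":"
    with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
        out = model(batch)

print(out.dtype, out.shape)  # bfloat16 under autocast; cast back with .float() if needed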
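forecast_controller drives the data pipeline by hand with next() and sniffs for a tuple to detect the final payload; the diff's comment that "the final yield of a generator is its return value" describes that convention loosely. Python's actual return-value channel is StopIteration.value, and the same contract can be written without the isinstance check. A sketch of that alternative, with stand-in names:

def pipeline_sketch():
    yield "downloading..."
    yield "preprocessing..."
    return ("input_tensor", "last_input_map", "target_map")  # final payload

gen = pipeline_sketch()
while True:
    try:
        status = next(gen)    # a status string to forward to the UI log
    except StopIteration as stop:
        result = stop.value   # the `return` payload lands here
        break
    print(status)

print(result)  # ('input_tensor', 'last_input_map', 'target_map')

Either form works; yielding the payload as the last item, as the commit does, just keeps everything in one channel at the cost of the tuple check.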
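On the UI side, the click handler is now a generator, and Gradio applies each yielded dict to the interface as it is produced; a yielded dict may update only a subset of components, but every component it might touch must appear in `outputs`, which is why the outputs list in the diff grew to ten entries. A minimal self-contained sketch of that wiring, assuming Gradio 4.x:

import time
import gradio as gr

with gr.Blocks() as demo:
    log_box = gr.Textbox(label="Log", interactive=False)
    run_button = gr.Button("Run")

    def run():
        # Each yielded dict is streamed to the UI immediately
        yield {log_box: "working...", run_button: gr.update(interactive=False)}
        time.sleep(1)  # stand-in for the slow download/inference stages
        yield {log_box: "✅ done", run_button: gr.update(interactive=True)}

    # Every component the generator may update has to be listed here
    run_button.click(fn=run, inputs=[], outputs=[log_box, run_button])

if __name__ == "__main__":
    demo.launch()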