Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
# Save this file as in the root of the cloned Surya repository
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
from huggingface_hub import snapshot_download
|
@@ -20,18 +18,15 @@ import matplotlib.pyplot as plt
|
|
20 |
import sunpy.visualization.colormaps as sunpy_cm
|
21 |
import traceback
|
22 |
|
23 |
-
# --- Use the official Surya modules ---
|
24 |
from surya.models.helio_spectformer import HelioSpectFormer
|
25 |
from surya.utils.data import build_scalers
|
26 |
from surya.datasets.helio import inverse_transform_single_channel
|
27 |
|
28 |
-
# --- Configuration ---
|
29 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
30 |
warnings.filterwarnings("ignore", category=FutureWarning)
|
31 |
logging.basicConfig(level=logging.INFO)
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
34 |
-
# Global cache for model, config, etc.
|
35 |
APP_CACHE = {}
|
36 |
|
37 |
SDO_CHANNELS_MAP = {
|
@@ -45,13 +40,12 @@ SDO_CHANNELS_MAP = {
|
|
45 |
"aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
|
46 |
"hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
|
47 |
"hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
48 |
-
"hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
49 |
-
"hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
50 |
"hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
|
51 |
}
|
52 |
SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
|
53 |
|
54 |
-
# --- 1. Model Loading and Setup ---
|
55 |
def setup_and_load_model():
|
56 |
if "model" in APP_CACHE:
|
57 |
yield "Model already loaded. Skipping setup."
|
@@ -70,7 +64,18 @@ def setup_and_load_model():
|
|
70 |
|
71 |
yield "Initializing model architecture..."
|
72 |
model_config = config["model"]
|
73 |
-
model = HelioSpectFormer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
75 |
APP_CACHE["device"] = device
|
76 |
|
@@ -82,13 +87,12 @@ def setup_and_load_model():
|
|
82 |
APP_CACHE["model"] = model
|
83 |
yield "✅ Model setup complete."
|
84 |
|
85 |
-
# --- 2. Live Data Fetching and Preprocessing (as a generator) ---
|
86 |
def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
87 |
config = APP_CACHE["config"]
|
88 |
img_size = config["model"]["img_size"]
|
89 |
|
90 |
input_deltas = config["data"]["time_delta_input_minutes"]
|
91 |
-
target_delta = forecast_horizon_minutes
|
92 |
input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
|
93 |
target_time = target_dt + datetime.timedelta(minutes=target_delta)
|
94 |
all_times = sorted(list(set(input_times + [target_time])))
|
@@ -107,12 +111,11 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
107 |
if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
|
108 |
continue
|
109 |
|
110 |
-
# *** FIX: Use a.Time.nearest=True for robust fetching instead of a time window ***
|
111 |
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
112 |
query = Fido.search(a.Time(t), instrument, physobs, sample, a.Time.nearest==True)
|
113 |
|
114 |
if not query: raise ValueError(f"No data found for {channel} near {t}")
|
115 |
-
files = Fido.fetch(query, path="./data/sdo_cache")
|
116 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
117 |
|
118 |
yield "✅ All files downloaded. Starting preprocessing..."
|
@@ -146,22 +149,42 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
146 |
|
147 |
yield (input_tensor, last_input_map, target_map)
|
148 |
|
149 |
-
|
150 |
-
# --- 3. Inference and Visualization ---
|
151 |
def run_inference(input_tensor):
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
# --- 4. Gradio UI and Controllers ---
|
160 |
def forecast_controller(date_str, hour, minute, forecast_horizon):
|
161 |
yield {
|
162 |
log_box: gr.update(value="Starting forecast...", visible=True),
|
163 |
run_button: gr.update(interactive=False),
|
164 |
-
# Also disable the other controls
|
165 |
date_input: gr.update(interactive=False),
|
166 |
hour_slider: gr.update(interactive=False),
|
167 |
minute_slider: gr.update(interactive=False),
|
@@ -175,11 +198,9 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
|
|
175 |
for status in setup_and_load_model():
|
176 |
yield { log_box: status }
|
177 |
|
178 |
-
# Construct datetime from the new UI components
|
179 |
target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
|
180 |
|
181 |
data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
|
182 |
-
# The rest of the generator logic remains the same...
|
183 |
while True:
|
184 |
try:
|
185 |
status = next(data_pipeline)
|
@@ -199,14 +220,20 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
|
|
199 |
yield {
|
200 |
log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
|
201 |
results_group: gr.update(visible=True),
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
203 |
}
|
204 |
|
205 |
except Exception as e:
|
206 |
-
|
|
|
|
|
207 |
|
208 |
finally:
|
209 |
-
# Re-enable all controls
|
210 |
yield {
|
211 |
run_button: gr.update(interactive=True),
|
212 |
date_input: gr.update(interactive=True),
|
@@ -215,14 +242,24 @@ def forecast_controller(date_str, hour, minute, forecast_horizon):
|
|
215 |
horizon_slider: gr.update(interactive=True),
|
216 |
}
|
217 |
|
218 |
-
# --- 5. Gradio UI Definition ---
|
219 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
220 |
-
|
221 |
-
|
|
|
222 |
|
223 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
-
# --- NEW: Controls Section ---
|
226 |
with gr.Accordion("Step 1: Configure Forecast", open=True):
|
227 |
with gr.Row():
|
228 |
date_input = gr.Textbox(
|
@@ -238,20 +275,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
238 |
|
239 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
240 |
|
241 |
-
# --- NEW: Moved log box to its own section ---
|
242 |
with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
|
243 |
-
log_box = gr.Textbox(label="Log", interactive=False, visible=
|
244 |
|
245 |
-
# --- Results section is now Step 3 ---
|
246 |
with gr.Group(visible=False) as results_group:
|
247 |
gr.Markdown("### Step 3: Explore Results")
|
248 |
-
channel_selector = gr.Dropdown(
|
|
|
|
|
249 |
with gr.Row():
|
250 |
-
input_display = gr.Image(
|
251 |
-
prediction_display = gr.Image(
|
252 |
-
target_display = gr.Image(
|
253 |
|
254 |
-
# --- Event Handlers ---
|
255 |
run_button.click(
|
256 |
fn=forecast_controller,
|
257 |
inputs=[date_input, hour_slider, minute_slider, horizon_slider],
|
@@ -262,9 +298,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
262 |
]
|
263 |
)
|
264 |
|
265 |
-
channel_selector.change(
|
|
|
|
|
|
|
|
|
266 |
|
267 |
if __name__ == "__main__":
|
268 |
-
|
269 |
-
# This is a condensed version showing only the key changes
|
270 |
demo.launch(debug=True)
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from huggingface_hub import snapshot_download
|
|
|
18 |
import sunpy.visualization.colormaps as sunpy_cm
|
19 |
import traceback
|
20 |
|
|
|
21 |
from surya.models.helio_spectformer import HelioSpectFormer
|
22 |
from surya.utils.data import build_scalers
|
23 |
from surya.datasets.helio import inverse_transform_single_channel
|
24 |
|
|
|
25 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
26 |
warnings.filterwarnings("ignore", category=FutureWarning)
|
27 |
logging.basicConfig(level=logging.INFO)
|
28 |
logger = logging.getLogger(__name__)
|
29 |
|
|
|
30 |
APP_CACHE = {}
|
31 |
|
32 |
SDO_CHANNELS_MAP = {
|
|
|
40 |
"aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
|
41 |
"hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
|
42 |
"hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
43 |
+
"hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
44 |
+
"hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
45 |
"hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
|
46 |
}
|
47 |
SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
|
48 |
|
|
|
49 |
def setup_and_load_model():
|
50 |
if "model" in APP_CACHE:
|
51 |
yield "Model already loaded. Skipping setup."
|
|
|
64 |
|
65 |
yield "Initializing model architecture..."
|
66 |
model_config = config["model"]
|
67 |
+
model = HelioSpectFormer(
|
68 |
+
img_size=model_config["img_size"], patch_size=model_config["patch_size"],
|
69 |
+
in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
|
70 |
+
time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
|
71 |
+
depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
|
72 |
+
num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
|
73 |
+
drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
|
74 |
+
window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
|
75 |
+
learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
|
76 |
+
init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
|
77 |
+
rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
|
78 |
+
)
|
79 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
80 |
APP_CACHE["device"] = device
|
81 |
|
|
|
87 |
APP_CACHE["model"] = model
|
88 |
yield "✅ Model setup complete."
|
89 |
|
|
|
90 |
def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
91 |
config = APP_CACHE["config"]
|
92 |
img_size = config["model"]["img_size"]
|
93 |
|
94 |
input_deltas = config["data"]["time_delta_input_minutes"]
|
95 |
+
target_delta = forecast_horizon_minutes
|
96 |
input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
|
97 |
target_time = target_dt + datetime.timedelta(minutes=target_delta)
|
98 |
all_times = sorted(list(set(input_times + [target_time])))
|
|
|
111 |
if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
|
112 |
continue
|
113 |
|
|
|
114 |
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
115 |
query = Fido.search(a.Time(t), instrument, physobs, sample, a.Time.nearest==True)
|
116 |
|
117 |
if not query: raise ValueError(f"No data found for {channel} near {t}")
|
118 |
+
files = Fido.fetch(query, path="./data/sdo_cache")
|
119 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
120 |
|
121 |
yield "✅ All files downloaded. Starting preprocessing..."
|
|
|
149 |
|
150 |
yield (input_tensor, last_input_map, target_map)
|
151 |
|
|
|
|
|
152 |
def run_inference(input_tensor):
|
153 |
+
model = APP_CACHE["model"]
|
154 |
+
device = APP_CACHE["device"]
|
155 |
+
time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
|
156 |
+
time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
|
157 |
+
input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
|
158 |
+
with torch.no_grad():
|
159 |
+
with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
|
160 |
+
prediction = model(input_batch)
|
161 |
+
return prediction.cpu()
|
162 |
|
163 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
164 |
+
if last_input_map is None: return None, None, None
|
165 |
+
c_idx = SDO_CHANNELS.index(channel_name)
|
166 |
+
scaler = APP_CACHE["scalers"]
|
167 |
+
all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
|
168 |
+
mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
|
169 |
+
pred_slice = inverse_transform_single_channel(
|
170 |
+
prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
|
171 |
+
)
|
172 |
+
vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
|
173 |
+
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
174 |
+
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
175 |
+
def to_pil(data, flip=False):
|
176 |
+
data_clipped = np.nan_to_num(data)
|
177 |
+
data_clipped = np.clip(data_clipped, 0, vmax)
|
178 |
+
data_norm = data_clipped / vmax if vmax > 0 else data_clipped
|
179 |
+
colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
|
180 |
+
img = Image.fromarray(colored)
|
181 |
+
return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
|
182 |
+
return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)
|
183 |
|
|
|
184 |
def forecast_controller(date_str, hour, minute, forecast_horizon):
|
185 |
yield {
|
186 |
log_box: gr.update(value="Starting forecast...", visible=True),
|
187 |
run_button: gr.update(interactive=False),
|
|
|
188 |
date_input: gr.update(interactive=False),
|
189 |
hour_slider: gr.update(interactive=False),
|
190 |
minute_slider: gr.update(interactive=False),
|
|
|
198 |
for status in setup_and_load_model():
|
199 |
yield { log_box: status }
|
200 |
|
|
|
201 |
target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
|
202 |
|
203 |
data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
|
|
|
204 |
while True:
|
205 |
try:
|
206 |
status = next(data_pipeline)
|
|
|
220 |
yield {
|
221 |
log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
|
222 |
results_group: gr.update(visible=True),
|
223 |
+
state_last_input: last_input_map,
|
224 |
+
state_prediction: prediction_tensor,
|
225 |
+
state_target: target_map,
|
226 |
+
input_display: img_in,
|
227 |
+
prediction_display: img_pred,
|
228 |
+
target_display: img_target,
|
229 |
}
|
230 |
|
231 |
except Exception as e:
|
232 |
+
error_str = traceback.format_exc()
|
233 |
+
logger.error(f"An error occurred: {e}\n{error_str}")
|
234 |
+
yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" }
|
235 |
|
236 |
finally:
|
|
|
237 |
yield {
|
238 |
run_button: gr.update(interactive=True),
|
239 |
date_input: gr.update(interactive=True),
|
|
|
242 |
horizon_slider: gr.update(interactive=True),
|
243 |
}
|
244 |
|
|
|
245 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
246 |
+
state_last_input = gr.State()
|
247 |
+
state_prediction = gr.State()
|
248 |
+
state_target = gr.State()
|
249 |
|
250 |
+
gr.Markdown(
|
251 |
+
"""
|
252 |
+
<div align='center'>
|
253 |
+
# ☀️ Surya: Live Forecast Demo ☀️
|
254 |
+
### Generate a real forecast for any recent date using NASA's Heliophysics Model.
|
255 |
+
**Instructions:**
|
256 |
+
1. Pick a date and time (at least 3 hours in the past).
|
257 |
+
2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
|
258 |
+
3. Once complete, select different channels to explore the multi-spectrum forecast.
|
259 |
+
</div>
|
260 |
+
"""
|
261 |
+
)
|
262 |
|
|
|
263 |
with gr.Accordion("Step 1: Configure Forecast", open=True):
|
264 |
with gr.Row():
|
265 |
date_input = gr.Textbox(
|
|
|
275 |
|
276 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
277 |
|
|
|
278 |
with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
|
279 |
+
log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
|
280 |
|
|
|
281 |
with gr.Group(visible=False) as results_group:
|
282 |
gr.Markdown("### Step 3: Explore Results")
|
283 |
+
channel_selector = gr.Dropdown(
|
284 |
+
choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
|
285 |
+
)
|
286 |
with gr.Row():
|
287 |
+
input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
|
288 |
+
prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
|
289 |
+
target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)
|
290 |
|
|
|
291 |
run_button.click(
|
292 |
fn=forecast_controller,
|
293 |
inputs=[date_input, hour_slider, minute_slider, horizon_slider],
|
|
|
298 |
]
|
299 |
)
|
300 |
|
301 |
+
channel_selector.change(
|
302 |
+
fn=generate_visualization,
|
303 |
+
inputs=[state_last_input, state_prediction, state_target, channel_selector],
|
304 |
+
outputs=[input_display, prediction_display, target_display]
|
305 |
+
)
|
306 |
|
307 |
if __name__ == "__main__":
|
308 |
+
os.makedirs("./data/sdo_cache", exist_ok=True)
|
|
|
309 |
demo.launch(debug=True)
|