Update app.py
app.py CHANGED
@@ -4,12 +4,7 @@ from huggingface_hub import snapshot_download
 import yaml
 import numpy as np
 from PIL import Image
-import sunpy.map
-import sunpy.net.attrs as a
-from sunpy.net import Fido
-from astropy.wcs import WCS
-import astropy.units as u
-from reproject import reproject_interp
+import requests
 import os
 import warnings
 import logging
@@ -17,6 +12,7 @@ import datetime
 import matplotlib.pyplot as plt
 import sunpy.visualization.colormaps as sunpy_cm
 import traceback
+from io import BytesIO
 
 from surya.models.helio_spectformer import HelioSpectFormer
 from surya.utils.data import build_scalers
@@ -29,22 +25,13 @@ logger = logging.getLogger(__name__)
 
 APP_CACHE = {}
 
-
-    "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
-    "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
-    "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
-    "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
-    "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
-    "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
-    "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
-    "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
-    "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
-    "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
+CHANNEL_TO_URL_CODE = {
+    "aia94": "0094", "aia131": "0131", "aia171": "0171", "aia193": "0193",
+    "aia211": "0211", "aia304": "0304", "aia335": "0335", "aia1600": "1600",
+    "hmi_m": "HMIBC", "hmi_bx": "HMIB", "hmi_by": "HMIB",
+    "hmi_bz": "HMIB", "hmi_v": "HMID"
 }
-SDO_CHANNELS = list(
+SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())
 
 def setup_and_load_model():
     if "model" in APP_CACHE:
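
For orientation, here is a minimal sketch of how one of these URL codes becomes a browse-image address, following the same base URL and YYYYMMDD_HHMMSS_4096_<code>.jpg naming pattern that fetch_browse_image uses in the next hunk; the timestamp is only an example, and CHANNEL_TO_URL_CODE is assumed to be in scope as defined above.

import datetime

code = CHANNEL_TO_URL_CODE["aia171"]            # "0171"
dt = datetime.datetime(2024, 5, 10, 12, 0, 0)   # example timestamp (UTC), purely illustrative
url = (
    "https://sdo.gsfc.nasa.gov/assets/img/browse/"
    + dt.strftime("%Y/%m/%d/")
    + dt.strftime(f"%Y%m%d_%H%M%S_4096_{code}.jpg")
)
# -> https://sdo.gsfc.nasa.gov/assets/img/browse/2024/05/10/20240510_120000_4096_0171.jpg
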
@@ -87,67 +74,56 @@ def setup_and_load_model():
     APP_CACHE["model"] = model
     yield "✅ Model setup complete."
 
+def fetch_browse_image(channel, target_dt, max_retries=15):
+    url_code = CHANNEL_TO_URL_CODE[channel]
+    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
+
+    for i in range(max_retries):
+        dt_to_try = target_dt - datetime.timedelta(minutes=i)
+        date_str = dt_to_try.strftime("%Y/%m/%d")
+        img_str = dt_to_try.strftime(f"%Y%m%d_%H%M%S_4096_{url_code}.jpg")
+        url = f"{base_url}/{date_str}/{img_str}"
+
+        response = requests.get(url)
+        if response.status_code == 200:
+            logger.info(f"Successfully found image for {channel} at {dt_to_try}")
+            return Image.open(BytesIO(response.content))
+
+    raise FileNotFoundError(f"Could not find any recent image for {channel} within {max_retries} minutes of {target_dt}.")
+
 def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
     config = APP_CACHE["config"]
     img_size = config["model"]["img_size"]
 
     input_deltas = config["data"]["time_delta_input_minutes"]
-    target_delta = forecast_horizon_minutes
     input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
-    target_time = target_dt + datetime.timedelta(minutes=target_delta)
+    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
     all_times = sorted(list(set(input_times + [target_time])))
 
-
-
-
-
-
+    images = {}
+    total_fetches = len(all_times) * len(SDO_CHANNELS)
+    fetches_done = 0
+    yield f"Starting search for {total_fetches} data files..."
+
     for t in all_times:
-
-        for
-
-            yield f"
-
-            if channel in ["hmi_by", "hmi_bz"]:
-                if data_maps[t].get("hmi_bx"):
-                    smap = data_maps[t]["hmi_bx"]
-                    data_maps[t][channel] = smap
-                    last_successful_map[channel] = smap
-                    continue
+        images[t] = {}
+        for channel in SDO_CHANNELS:
+            fetches_done += 1
+            yield f"Searching [{fetches_done}/{total_fetches}]: {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
+            images[t][channel] = fetch_browse_image(channel, t)
 
-
-            instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
-            query = Fido.search(time_attr, instrument, physobs, sample)
-
-            if query:
-                files = Fido.fetch(query[0,0], path="./data/sdo_cache")
-                smap = sunpy.map.Map(files[0])
-                data_maps[t][channel] = smap
-                last_successful_map[channel] = smap
-            elif channel in last_successful_map:
-                yield f"⚠️ WARNING: No data for {channel} near {t}. Reusing previous image."
-                data_maps[t][channel] = last_successful_map[channel]
-            else:
-                raise ValueError(f"CRITICAL: No data found for initial image of {channel}. Cannot proceed.")
-
-    yield "✅ All files downloaded. Starting preprocessing..."
-    output_wcs = WCS(naxis=2)
-    output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
-    output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
-    output_wcs.wcs.crval = [0, 0] * u.arcsec
-    output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
-
+    yield "✅ All images found. Starting preprocessing..."
     scaler = APP_CACHE["scalers"]
     processed_tensors = {}
-    for t,
+    for t, channel_images in images.items():
         channel_tensors = []
         for i, channel in enumerate(SDO_CHANNELS):
-
-
+            img = channel_images[channel]
+            if img.mode != 'L':
+                img = img.convert('L')
 
-
-
-            norm_data = reprojected_data / exp_time
+            img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
+            norm_data = np.array(img_resized, dtype=np.float32)
 
             scaled_data = scaler.transform(norm_data.reshape(-1, 1), c_idx=i).reshape(norm_data.shape)
             channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
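
fetch_and_process_sdo_data is a generator: it yields progress strings while it works and, as its last value, yields the (input_tensor, last_input_image_map, target_image_map) tuple. Below is a minimal consumption sketch, assuming the function is importable from this app; the helper name and the example timestamp are hypothetical, and the app's own forecast_controller is the real consumer.

import datetime

def collect_model_inputs(target_dt, horizon_minutes):
    # Hypothetical helper, not part of app.py: print the progress strings and
    # keep the generator's final non-string payload.
    payload = None
    for item in fetch_and_process_sdo_data(target_dt, horizon_minutes):
        if isinstance(item, str):
            print(item)        # progress message yielded by the generator
        else:
            payload = item     # (input_tensor, last_input_image_map, target_image_map)
    return payload

# Example call with a hypothetical date and a 12-minute horizon:
# tensors, last_imgs, target_imgs = collect_model_inputs(datetime.datetime(2024, 5, 10, 12, 0), 12)
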
@@ -156,10 +132,10 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
     yield "✅ Preprocessing complete."
     input_tensor_list = [processed_tensors[t] for t in input_times]
     input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
-
-
+    target_image_map = images[target_time]
+    last_input_image_map = images[input_times[-1]]
 
-    yield (input_tensor,
+    yield (input_tensor, last_input_image_map, target_image_map)
 
 def run_inference(input_tensor):
     model = APP_CACHE["model"]
@@ -181,17 +157,20 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
     pred_slice = inverse_transform_single_channel(
         prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
     )
-
+
+    target_img_data = np.array(target_map[channel_name])
+    vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
     cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
     cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
-
+
+    def to_pil(data):
         data_clipped = np.nan_to_num(data)
         data_clipped = np.clip(data_clipped, 0, vmax)
         data_norm = data_clipped / vmax if vmax > 0 else data_clipped
         colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
-
-
-    return
+        return Image.fromarray(colored)
+
+    return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]
 
 def forecast_controller(date_str, hour, minute, forecast_horizon):
     yield {
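
Each call to generate_visualization returns three images for one channel: the last input frame, the colorized prediction, and the fetched target. The sketch below is an illustrative helper, not part of app.py, that pastes them into a single side-by-side strip; it assumes the three values behave like PIL images of comparable size.

from PIL import Image

def comparison_strip(last_input_img, pred_img, target_img):
    # Hypothetical helper: paste input | prediction | target side by side
    # for a quick visual sanity check of a single channel.
    w, h = pred_img.size
    panels = [im.convert("RGB").resize((w, h)) for im in (last_input_img, pred_img, target_img)]
    strip = Image.new("RGB", (3 * w, h))
    for k, panel in enumerate(panels):
        strip.paste(panel, (k * w, 0))
    return strip
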
@@ -265,7 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         # ☀️ Surya: Live Forecast Demo ☀️
         ### A Foundation Model for Solar Dynamics
         This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
-        It looks at the Sun in 13 different channels (
+        It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
+        <br>
+        <p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>
         </div>
         """
     )
@@ -273,11 +254,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Accordion("Step 1: Configure Forecast", open=True):
         with gr.Row():
             date_input = gr.Textbox(
-                label="Date",
-                value=datetime.
+                label="Date (YYYY-MM-DD)",
+                value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).strftime("%Y-%m-%d")
             )
-            hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=datetime.datetime.
-            minute_slider = gr.Slider(label="Minute", minimum=0, maximum=59, step=1, value=datetime.datetime.
+            hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).hour)
+            minute_slider = gr.Slider(label="Minute (UTC)", minimum=0, maximum=59, step=1, value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).minute)
             horizon_slider = gr.Slider(
                 label="Forecast Horizon (minutes ahead)",
                 minimum=12, maximum=120, step=12, value=12
@@ -285,7 +266,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     run_button = gr.Button("🔮 Generate Forecast", variant="primary")
 
-    with gr.Accordion("Step 2: View Log", open=False)
+    with gr.Accordion("Step 2: View Log", open=False):
         log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
 
     with gr.Group(visible=False) as results_group:
@@ -315,5 +296,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     )
 
 if __name__ == "__main__":
-    os.makedirs("./data/sdo_cache", exist_ok=True)
     demo.launch(debug=True)