broadfield-dev committed
Commit 4ac0d86 · verified · 1 Parent(s): a2ea901

Update app.py

Files changed (1): app.py (+62, -82)
app.py CHANGED
@@ -4,12 +4,7 @@ from huggingface_hub import snapshot_download
  import yaml
  import numpy as np
  from PIL import Image
- import sunpy.map
- import sunpy.net.attrs as a
- from sunpy.net import Fido
- from astropy.wcs import WCS
- import astropy.units as u
- from reproject import reproject_interp
+ import requests
  import os
  import warnings
  import logging
@@ -17,6 +12,7 @@ import datetime
  import matplotlib.pyplot as plt
  import sunpy.visualization.colormaps as sunpy_cm
  import traceback
+ from io import BytesIO
 
  from surya.models.helio_spectformer import HelioSpectFormer
  from surya.utils.data import build_scalers
@@ -29,22 +25,13 @@ logger = logging.getLogger(__name__)
 
  APP_CACHE = {}
 
- SDO_CHANNELS_MAP = {
-     "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
-     "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
-     "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
-     "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
-     "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
-     "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
-     "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
-     "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
-     "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
-     "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-     "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-     "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-     "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
+ CHANNEL_TO_URL_CODE = {
+     "aia94": "0094", "aia131": "0131", "aia171": "0171", "aia193": "0193",
+     "aia211": "0211", "aia304": "0304", "aia335": "0335", "aia1600": "1600",
+     "hmi_m": "HMIBC", "hmi_bx": "HMIB", "hmi_by": "HMIB",
+     "hmi_bz": "HMIB", "hmi_v": "HMID"
  }
- SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
+ SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())
 
  def setup_and_load_model():
      if "model" in APP_CACHE:
@@ -87,67 +74,56 @@ def setup_and_load_model():
      APP_CACHE["model"] = model
      yield "✅ Model setup complete."
 
+ def fetch_browse_image(channel, target_dt, max_retries=15):
+     url_code = CHANNEL_TO_URL_CODE[channel]
+     base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
+
+     for i in range(max_retries):
+         dt_to_try = target_dt - datetime.timedelta(minutes=i)
+         date_str = dt_to_try.strftime("%Y/%m/%d")
+         img_str = dt_to_try.strftime(f"%Y%m%d_%H%M%S_4096_{url_code}.jpg")
+         url = f"{base_url}/{date_str}/{img_str}"
+
+         response = requests.get(url)
+         if response.status_code == 200:
+             logger.info(f"Successfully found image for {channel} at {dt_to_try}")
+             return Image.open(BytesIO(response.content))
+
+     raise FileNotFoundError(f"Could not find any recent image for {channel} within {max_retries} minutes of {target_dt}.")
+
  def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
      config = APP_CACHE["config"]
      img_size = config["model"]["img_size"]
 
      input_deltas = config["data"]["time_delta_input_minutes"]
-     target_delta = forecast_horizon_minutes
      input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
-     target_time = target_dt + datetime.timedelta(minutes=target_delta)
+     target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
      all_times = sorted(list(set(input_times + [target_time])))
 
-     data_maps = {}
-     last_successful_map = {}
-     total_downloads = len(all_times) * len(SDO_CHANNELS)
-     downloads_done = 0
-     yield f"Starting download of {total_downloads} data files..."
+     images = {}
+     total_fetches = len(all_times) * len(SDO_CHANNELS)
+     fetches_done = 0
+     yield f"Starting search for {total_fetches} data files..."
+
      for t in all_times:
-         data_maps[t] = {}
-         for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
-             downloads_done += 1
-             yield f"Downloading [{downloads_done}/{total_downloads}]: {channel} for {t.strftime('%Y-%m-%d %H:%M')}..."
-
-             if channel in ["hmi_by", "hmi_bz"]:
-                 if data_maps[t].get("hmi_bx"):
-                     smap = data_maps[t]["hmi_bx"]
-                     data_maps[t][channel] = smap
-                     last_successful_map[channel] = smap
-                     continue
+         images[t] = {}
+         for channel in SDO_CHANNELS:
+             fetches_done += 1
+             yield f"Searching [{fetches_done}/{total_fetches}]: {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
+             images[t][channel] = fetch_browse_image(channel, t)
 
-             time_attr = a.Time(t - datetime.timedelta(minutes=5), t + datetime.timedelta(minutes=5))
-             instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
-             query = Fido.search(time_attr, instrument, physobs, sample)
-
-             if query:
-                 files = Fido.fetch(query[0,0], path="./data/sdo_cache")
-                 smap = sunpy.map.Map(files[0])
-                 data_maps[t][channel] = smap
-                 last_successful_map[channel] = smap
-             elif channel in last_successful_map:
-                 yield f"⚠️ WARNING: No data for {channel} near {t}. Reusing previous image."
-                 data_maps[t][channel] = last_successful_map[channel]
-             else:
-                 raise ValueError(f"CRITICAL: No data found for initial image of {channel}. Cannot proceed.")
-
-     yield "✅ All files downloaded. Starting preprocessing..."
-     output_wcs = WCS(naxis=2)
-     output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
-     output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
-     output_wcs.wcs.crval = [0, 0] * u.arcsec
-     output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
-
+     yield "✅ All images found. Starting preprocessing..."
      scaler = APP_CACHE["scalers"]
      processed_tensors = {}
-     for t, channel_maps in data_maps.items():
+     for t, channel_images in images.items():
          channel_tensors = []
          for i, channel in enumerate(SDO_CHANNELS):
-             smap = channel_maps[channel]
-             reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
+             img = channel_images[channel]
+             if img.mode != 'L':
+                 img = img.convert('L')
 
-             exp_time = smap.meta.get('exptime', 1.0)
-             if exp_time is None or exp_time <= 0: exp_time = 1.0
-             norm_data = reprojected_data / exp_time
+             img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
+             norm_data = np.array(img_resized, dtype=np.float32)
 
              scaled_data = scaler.transform(norm_data.reshape(-1, 1), c_idx=i).reshape(norm_data.shape)
              channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
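The rewritten loop above trades the FITS download/reproject/exposure-normalization pipeline for a much simpler JPEG route: convert to grayscale, resize to the model grid, then apply the per-channel scaler. A rough standalone sketch of that per-channel step; the scaler returned by build_scalers is not reproduced here, so a plain standardization stands in for scaler.transform:

    import numpy as np
    from PIL import Image

    def preprocess_browse_image(img: Image.Image, img_size: int) -> np.ndarray:
        """Grayscale, resize, and scale one browse image, mirroring the new loop."""
        if img.mode != "L":
            img = img.convert("L")
        img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
        norm_data = np.array(img_resized, dtype=np.float32)
        # Placeholder for scaler.transform(norm_data.reshape(-1, 1), c_idx=i)
        return (norm_data - norm_data.mean()) / (norm_data.std() + 1e-6)

    # Example with a synthetic image standing in for a downloaded browse JPEG:
    dummy = Image.new("RGB", (1024, 1024), color=(128, 128, 128))
    print(preprocess_browse_image(dummy, img_size=512).shape)  # (512, 512)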
@@ -156,10 +132,10 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
      yield "✅ Preprocessing complete."
      input_tensor_list = [processed_tensors[t] for t in input_times]
      input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
-     target_map = data_maps[target_time]
-     last_input_map = data_maps[input_times[-1]]
+     target_image_map = images[target_time]
+     last_input_image_map = images[input_times[-1]]
 
-     yield (input_tensor, last_input_map, target_map)
+     yield (input_tensor, last_input_image_map, target_image_map)
 
  def run_inference(input_tensor):
      model = APP_CACHE["model"]
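The unchanged stacking lines imply a (batch, channel, time, height, width) input layout: each per-timestep tensor of shape (channels, H, W) is stacked along a new time axis and then given a batch dimension. A small shape check under that assumption; the channel count is the demo's 13, while the number of input steps and the image size are placeholders that really come from the model config:

    import torch

    num_channels, num_steps, img_size = 13, 2, 512  # num_steps mirrors time_delta_input_minutes
    per_time = [torch.zeros(num_channels, img_size, img_size) for _ in range(num_steps)]

    input_tensor = torch.stack(per_time, dim=1).unsqueeze(0)
    print(input_tensor.shape)  # torch.Size([1, 13, 2, 512, 512])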
@@ -181,17 +157,20 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
      pred_slice = inverse_transform_single_channel(
          prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
      )
-     vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
+
+     target_img_data = np.array(target_map[channel_name])
+     vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
      cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
      cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
-     def to_pil(data, flip=False):
+
+     def to_pil(data):
          data_clipped = np.nan_to_num(data)
          data_clipped = np.clip(data_clipped, 0, vmax)
          data_norm = data_clipped / vmax if vmax > 0 else data_clipped
         colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
-         img = Image.fromarray(colored)
-         return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img
-     return to_pil(last_input_map[channel_name].data, flip=True), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data, flip=True)
+         return Image.fromarray(colored)
+
+     return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]
 
  def forecast_controller(date_str, hour, minute, forecast_horizon):
      yield {
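The reworked visualization drops the vertical flip (browse JPEGs arrive display-oriented, unlike the reprojected maps the old code handled) and clips at the 99.5th percentile before colormapping. A hedged sketch of that clip-and-colorize step, using a stock matplotlib colormap in place of the sunpy AIA/HMI colormaps the app looks up:

    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image

    def colorize(data: np.ndarray, cmap_name: str = "gray") -> Image.Image:
        """Clip a 2D array at its 99.5th percentile, normalize, and map it to a PIL image."""
        data = np.nan_to_num(data)
        vmax = np.quantile(data, 0.995)
        data = np.clip(data, 0, vmax)
        norm = data / vmax if vmax > 0 else data
        rgb = (plt.get_cmap(cmap_name)(norm)[:, :, :3] * 255).astype(np.uint8)
        return Image.fromarray(rgb)

    # Random data standing in for one predicted channel slice:
    preview = colorize(np.random.rand(512, 512))
    print(preview.size)  # (512, 512)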
@@ -265,7 +244,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
          # ☀️ Surya: Live Forecast Demo ☀️
          ### A Foundation Model for Solar Dynamics
          This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
-         It looks at the Sun in 13 different channels (8 from the AIA instrument, 5 from HMI) simultaneously to learn the complex relationships between solar phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
+         It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
+         <br>
+         <p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>
          </div>
          """
      )
@@ -273,11 +254,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      with gr.Accordion("Step 1: Configure Forecast", open=True):
          with gr.Row():
              date_input = gr.Textbox(
-                 label="Date",
-                 value=datetime.date.today().strftime("%Y-%m-%d")
+                 label="Date (YYYY-MM-DD)",
+                 value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).strftime("%Y-%m-%d")
              )
-             hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=datetime.datetime.utcnow().hour - 3)
-             minute_slider = gr.Slider(label="Minute", minimum=0, maximum=59, step=1, value=datetime.datetime.utcnow().minute)
+             hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).hour)
+             minute_slider = gr.Slider(label="Minute (UTC)", minimum=0, maximum=59, step=1, value=(datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)).minute)
              horizon_slider = gr.Slider(
                  label="Forecast Horizon (minutes ahead)",
                  minimum=12, maximum=120, step=12, value=12
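The new defaults also switch from the deprecated datetime.datetime.utcnow() to timezone-aware UTC and push the default timestamp back three hours, presumably so the requested browse images have already been published. A small sketch of that default computation:

    import datetime

    # Default forecast start time: current UTC minus three hours, matching the new widget defaults.
    default_dt = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)
    date_value = default_dt.strftime("%Y-%m-%d")
    hour_value, minute_value = default_dt.hour, default_dt.minute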
@@ -285,7 +266,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
      run_button = gr.Button("🔮 Generate Forecast", variant="primary")
 
-     with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
+     with gr.Accordion("Step 2: View Log", open=False):
          log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)
 
      with gr.Group(visible=False) as results_group:
@@ -315,5 +296,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      )
 
  if __name__ == "__main__":
-     os.makedirs("./data/sdo_cache", exist_ok=True)
      demo.launch(debug=True)
 