broadfield-dev committed on
Commit
d6afb4c
·
verified ·
1 Parent(s): acedaa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  import gradio as gr
4
  import torch
5
- from torch.utils.data import DataLoader
6
  from huggingface_hub import snapshot_download
7
  import yaml
8
  import numpy as np
@@ -10,11 +9,9 @@ from PIL import Image
10
  import sunpy.map
11
  import sunpy.net.attrs as a
12
  from sunpy.net import Fido
13
- from sunpy.coordinates import Helioprojective
14
- from astropy.coordinates import SkyCoord
15
  from astropy.wcs import WCS
16
  import astropy.units as u
17
- from reproject import reproject_interp
18
  import os
19
  import warnings
20
  import logging
@@ -27,7 +24,8 @@ from surya.models.helio_spectformer import HelioSpectFormer
27
  from surya.utils.data import build_scalers, inverse_transform_single_channel
28
 
29
  # --- Configuration ---
30
- warnings.filterwarnings("ignore")
 
31
  logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
@@ -94,14 +92,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
94
  config = APP_CACHE["config"]
95
  img_size = config["model"]["img_size"][0]
96
 
97
- # Define time windows for input and target (ground truth)
98
  input_deltas = config["data"]["time_delta_input_minutes"]
99
  target_delta = config["data"]["time_delta_target_minutes"][0]
100
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
101
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
102
  all_times = sorted(list(set(input_times + [target_time])))
103
 
104
- # Download data for all required timestamps
105
  data_maps = {}
106
  total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
107
  downloads_done = 0
@@ -110,53 +106,56 @@ def fetch_and_process_sdo_data(target_dt, progress):
110
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
111
  progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
112
 
113
- # HMI vector fields are not standard products, use LoS as a placeholder for demo
114
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
115
  if channel in ["hmi_by", "hmi_bz"]:
116
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
117
  continue
118
 
119
  time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
120
- query = Fido.search(time_attr, a.Instrument.aia, physobs, sample) if "aia" in channel else Fido.search(time_attr, a.Instrument.hmi, physobs, sample)
 
 
 
 
 
 
 
121
 
122
  if not query: raise ValueError(f"No data found for {channel} at {t}")
123
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
124
  data_maps[t][channel] = sunpy.map.Map(files[0])
125
  downloads_done += 1
126
 
127
- # Create target WCS for reprojection
128
  output_wcs = WCS(naxis=2)
129
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
130
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
131
  output_wcs.wcs.crval = [0, 0] * u.arcsec
132
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
133
 
134
- # Process data
135
  processed_tensors = {}
 
 
136
  for t, channel_maps in data_maps.items():
137
  channel_tensors = []
138
  for i, channel in enumerate(SDO_CHANNELS):
139
- progress(i / len(SDO_CHANNELS), desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
140
  smap = channel_maps[channel]
141
 
142
- # Reproject to common grid
143
  reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
144
 
145
- # Normalize by exposure time and apply signed-log transform
146
  exp_time = smap.meta.get('exptime', 1.0)
147
- if exp_time <= 0: exp_time = 1.0
148
  norm_data = reprojected_data / exp_time
149
 
150
- # Apply the same scaling as the training pipeline
151
  scaler = APP_CACHE["scalers"][channel]
152
  scaled_data = scaler.transform(norm_data)
153
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
 
154
  processed_tensors[t] = torch.stack(channel_tensors)
155
 
156
- # Assemble final input and target tensors
157
  input_tensor_list = [processed_tensors[t] for t in input_times]
158
- input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0) # Add batch dim
159
- target_map = data_maps[target_time] # Return raw map for ground truth vis
160
  last_input_map = data_maps[input_times[-1]]
161
 
162
  return input_tensor, last_input_map, target_map
@@ -191,8 +190,7 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
191
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
192
  )
193
 
194
- # Get colormap and normalization
195
- vmax = np.quantile(target_map[channel_name].data, 0.995)
196
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
197
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
198
 
@@ -206,7 +204,6 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
206
 
207
  return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
208
 
209
-
210
  # --- 4. Gradio UI and Controllers ---
211
  def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
212
  try:
@@ -223,13 +220,12 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
223
 
224
  prediction_tensor = run_inference(input_tensor)
225
 
226
- # Default visualization for aia171
227
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
228
 
229
  status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
230
  logger.info(status)
231
 
232
- return (last_input_map, prediction_tensor, target_map, # state
233
  img_in, img_pred, img_target, status, gr.update(visible=True))
234
 
235
  except Exception as e:
@@ -243,7 +239,6 @@ def update_visualization_controller(last_input_map, prediction_tensor, target_ma
243
 
244
 
245
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
246
- # State objects to hold the data after a forecast is run
247
  state_last_input = gr.State()
248
  state_prediction = gr.State()
249
  state_target = gr.State()
@@ -263,7 +258,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
263
 
264
  with gr.Row():
265
  datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
266
- value=(datetime.datetime.now() - datetime.timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S"))
267
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
268
 
269
  with gr.Group(visible=False) as results_group:
@@ -288,6 +283,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
288
  )
289
 
290
  if __name__ == "__main__":
291
- # Create cache directory if it doesn't exist
292
  os.makedirs("./data/sdo_cache", exist_ok=True)
293
  demo.launch(debug=True)
 
2
 
3
  import gradio as gr
4
  import torch
 
5
  from huggingface_hub import snapshot_download
6
  import yaml
7
  import numpy as np
 
9
  import sunpy.map
10
  import sunpy.net.attrs as a
11
  from sunpy.net import Fido
 
 
12
  from astropy.wcs import WCS
13
  import astropy.units as u
14
+ from reproject import reproject_interp # Correct import statement
15
  import os
16
  import warnings
17
  import logging
 
24
  from surya.utils.data import build_scalers, inverse_transform_single_channel
25
 
26
  # --- Configuration ---
27
+ warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
28
+ warnings.filterwarnings("ignore", category=FutureWarning)
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
31
 
 
92
  config = APP_CACHE["config"]
93
  img_size = config["model"]["img_size"][0]
94
 
 
95
  input_deltas = config["data"]["time_delta_input_minutes"]
96
  target_delta = config["data"]["time_delta_target_minutes"][0]
97
  input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
98
  target_time = target_dt + datetime.timedelta(minutes=target_delta)
99
  all_times = sorted(list(set(input_times + [target_time])))
100
 
 
101
  data_maps = {}
102
  total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
103
  downloads_done = 0
 
106
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
107
  progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
108
 
 
109
  instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
110
  if channel in ["hmi_by", "hmi_bz"]:
111
  if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
112
  continue
113
 
114
  time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
115
+ search_query = [time_attr, physobs, sample]
116
+ # AIA and HMI queries are slightly different
117
+ if "aia" in channel:
118
+ search_query.append(a.Instrument.aia)
119
+ else:
120
+ search_query.append(a.Instrument.hmi)
121
+
122
+ query = Fido.search(*search_query)
123
 
124
  if not query: raise ValueError(f"No data found for {channel} at {t}")
125
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
126
  data_maps[t][channel] = sunpy.map.Map(files[0])
127
  downloads_done += 1
128
 
 
129
  output_wcs = WCS(naxis=2)
130
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
131
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
132
  output_wcs.wcs.crval = [0, 0] * u.arcsec
133
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
134
 
 
135
  processed_tensors = {}
136
+ total_processing = len(all_times) * len(SDO_CHANNELS)
137
+ processing_done = 0
138
  for t, channel_maps in data_maps.items():
139
  channel_tensors = []
140
  for i, channel in enumerate(SDO_CHANNELS):
141
+ progress(processing_done / total_processing, desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
142
  smap = channel_maps[channel]
143
 
 
144
  reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
145
 
 
146
  exp_time = smap.meta.get('exptime', 1.0)
147
+ if exp_time is None or exp_time <= 0: exp_time = 1.0
148
  norm_data = reprojected_data / exp_time
149
 
 
150
  scaler = APP_CACHE["scalers"][channel]
151
  scaled_data = scaler.transform(norm_data)
152
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
153
+ processing_done += 1
154
  processed_tensors[t] = torch.stack(channel_tensors)
155
 
 
156
  input_tensor_list = [processed_tensors[t] for t in input_times]
157
+ input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
158
+ target_map = data_maps[target_time]
159
  last_input_map = data_maps[input_times[-1]]
160
 
161
  return input_tensor, last_input_map, target_map
 
190
  mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
191
  )
192
 
193
+ vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
 
194
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
195
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
196
 
 
204
 
205
  return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
206
 
 
207
  # --- 4. Gradio UI and Controllers ---
208
  def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
209
  try:
 
220
 
221
  prediction_tensor = run_inference(input_tensor)
222
 
 
223
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
224
 
225
  status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
226
  logger.info(status)
227
 
228
+ return (last_input_map, prediction_tensor, target_map,
229
  img_in, img_pred, img_target, status, gr.update(visible=True))
230
 
231
  except Exception as e:
 
239
 
240
 
241
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
242
  state_last_input = gr.State()
243
  state_prediction = gr.State()
244
  state_target = gr.State()
 
258
 
259
  with gr.Row():
260
  datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
261
+ value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S"))
262
  run_button = gr.Button("🔮 Generate Forecast", variant="primary")
263
 
264
  with gr.Group(visible=False) as results_group:
 
283
  )
284
 
285
  if __name__ == "__main__":
 
286
  os.makedirs("./data/sdo_cache", exist_ok=True)
287
  demo.launch(debug=True)