broadfield-dev committed on
Commit
bf136f8
·
verified ·
1 Parent(s): d7970b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -35
app.py CHANGED
@@ -22,7 +22,6 @@ import sunpy.visualization.colormaps as sunpy_cm
22
  # --- Use the official Surya modules ---
23
  from surya.models.helio_spectformer import HelioSpectFormer
24
  from surya.utils.data import build_scalers
25
- # *** FIX: Corrected import location for the inverse transform function ***
26
  from surya.datasets.helio import inverse_transform_single_channel
27
 
28
  # --- Configuration ---
@@ -33,15 +32,17 @@ logger = logging.getLogger(__name__)
33
 
34
  # Global cache for model, config, etc.
35
  APP_CACHE = {}
 
 
36
  SDO_CHANNELS_MAP = {
37
- "aia94": (a.Wavelength(94, 94, "angstrom"), a.Sample(12 * u.s)),
38
- "aia131": (a.Wavelength(131, 131, "angstrom"), a.Sample(12 * u.s)),
39
- "aia171": (a.Wavelength(171, 171, "angstrom"), a.Sample(12 * u.s)),
40
- "aia193": (a.Wavelength(193, 193, "angstrom"), a.Sample(12 * u.s)),
41
- "aia211": (a.Wavelength(211, 211, "angstrom"), a.Sample(12 * u.s)),
42
- "aia304": (a.Wavelength(304, 304, "angstrom"), a.Sample(12 * u.s)),
43
- "aia335": (a.Wavelength(335, 335, "angstrom"), a.Sample(12 * u.s)),
44
- "aia1600": (a.Wavelength(1600, 1600, "angstrom"), a.Sample(24 * u.s)),
45
  "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
46
  "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
47
  "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
@@ -101,11 +102,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
101
  all_times = sorted(list(set(input_times + [target_time])))
102
 
103
  data_maps = {}
104
- total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
105
  downloads_done = 0
106
  for t in all_times:
107
  data_maps[t] = {}
108
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
 
109
  progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
110
 
111
  if channel in ["hmi_by", "hmi_bz"]:
@@ -113,19 +115,13 @@ def fetch_and_process_sdo_data(target_dt, progress):
113
  continue
114
 
115
  time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
116
- search_query_list = [time_attr, physobs, sample]
117
- if "aia" in channel:
118
- search_query_list.append(a.Instrument.aia)
119
- else:
120
- search_query_list.append(a.Instrument.hmi)
121
-
122
- query = Fido.search(*search_query_list)
123
 
124
  if not query: raise ValueError(f"No data found for {channel} at {t}")
125
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
126
  data_maps[t][channel] = sunpy.map.Map(files[0])
127
- downloads_done += 1
128
-
129
  output_wcs = WCS(naxis=2)
130
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
131
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
@@ -133,14 +129,10 @@ def fetch_and_process_sdo_data(target_dt, progress):
133
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
134
 
135
  processed_tensors = {}
136
- total_processing = len(all_times) * len(SDO_CHANNELS)
137
- processing_done = 0
138
  for t, channel_maps in data_maps.items():
139
  channel_tensors = []
140
  for i, channel in enumerate(SDO_CHANNELS):
141
- progress(processing_done / total_processing, desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
142
  smap = channel_maps[channel]
143
-
144
  reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
145
 
146
  exp_time = smap.meta.get('exptime', 1.0)
@@ -150,7 +142,6 @@ def fetch_and_process_sdo_data(target_dt, progress):
150
  scaler = APP_CACHE["scalers"][channel]
151
  scaled_data = scaler.transform(norm_data)
152
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
153
- processing_done += 1
154
  processed_tensors[t] = torch.stack(channel_tensors)
155
 
156
  input_tensor_list = [processed_tensors[t] for t in input_times]
@@ -178,15 +169,12 @@ def run_inference(input_tensor):
178
  return prediction.cpu()
179
 
180
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
181
- if last_input_map is None:
182
- return None, None, None
183
 
184
  c_idx = SDO_CHANNELS.index(channel_name)
185
-
186
- means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][SDO_CHANNELS[c_idx]].get_params()
187
  pred_slice = inverse_transform_single_channel(
188
- prediction_tensor[0, c_idx].numpy(),
189
- mean=means, std=stds, epsilon=epsilons, sl_scale_factor=sl_scale_factors
190
  )
191
 
192
  vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
@@ -206,8 +194,7 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
206
  # --- 4. Gradio UI and Controllers ---
207
  def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
208
  try:
209
- if not dt_str:
210
- raise gr.Error("Please select a date and time.")
211
 
212
  progress(0, desc="Initializing...")
213
  setup_and_load_model(progress)
@@ -216,7 +203,6 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
216
  logger.info(f"Starting forecast for target time: {target_dt}")
217
 
218
  input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
219
-
220
  prediction_tensor = run_inference(input_tensor)
221
 
222
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
@@ -232,8 +218,7 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
232
  raise gr.Error(f"Failed to generate forecast. Error: {e}")
233
 
234
  def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
235
- if last_input_map is None:
236
- return None, None, None
237
  return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
238
 
239
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
22
  # --- Use the official Surya modules ---
23
  from surya.models.helio_spectformer import HelioSpectFormer
24
  from surya.utils.data import build_scalers
 
25
  from surya.datasets.helio import inverse_transform_single_channel
26
 
27
  # --- Configuration ---
 
32
 
33
  # Global cache for model, config, etc.
34
  APP_CACHE = {}
35
+
36
+ # *** FIX: Corrected the a.Wavelength calls to use astropy units ***
37
  SDO_CHANNELS_MAP = {
38
+ "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
39
+ "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
40
+ "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
41
+ "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
42
+ "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
43
+ "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
44
+ "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
45
+ "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
46
  "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
47
  "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
48
  "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
 
102
  all_times = sorted(list(set(input_times + [target_time])))
103
 
104
  data_maps = {}
105
+ total_downloads = len(all_times) * len(SDO_CHANNELS)
106
  downloads_done = 0
107
  for t in all_times:
108
  data_maps[t] = {}
109
  for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
110
+ downloads_done += 1
111
  progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
112
 
113
  if channel in ["hmi_by", "hmi_bz"]:
 
115
  continue
116
 
117
  time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
118
+ instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
119
+ query = Fido.search(time_attr, instrument, physobs, sample)
 
 
 
 
 
120
 
121
  if not query: raise ValueError(f"No data found for {channel} at {t}")
122
  files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
123
  data_maps[t][channel] = sunpy.map.Map(files[0])
124
+
 
125
  output_wcs = WCS(naxis=2)
126
  output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
127
  output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
 
129
  output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
130
 
131
  processed_tensors = {}
 
 
132
  for t, channel_maps in data_maps.items():
133
  channel_tensors = []
134
  for i, channel in enumerate(SDO_CHANNELS):
 
135
  smap = channel_maps[channel]
 
136
  reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
137
 
138
  exp_time = smap.meta.get('exptime', 1.0)
 
142
  scaler = APP_CACHE["scalers"][channel]
143
  scaled_data = scaler.transform(norm_data)
144
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
 
145
  processed_tensors[t] = torch.stack(channel_tensors)
146
 
147
  input_tensor_list = [processed_tensors[t] for t in input_times]
 
169
  return prediction.cpu()
170
 
171
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
172
+ if last_input_map is None: return None, None, None
 
173
 
174
  c_idx = SDO_CHANNELS.index(channel_name)
175
+ means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][channel_name].get_params()
 
176
  pred_slice = inverse_transform_single_channel(
177
+ prediction_tensor[0, c_idx].numpy(), mean=means, std=stds, epsilon=epsilons, sl_scale_factor=sl_scale_factors
 
178
  )
179
 
180
  vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
 
194
  # --- 4. Gradio UI and Controllers ---
195
  def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
196
  try:
197
+ if not dt_str: raise gr.Error("Please select a date and time.")
 
198
 
199
  progress(0, desc="Initializing...")
200
  setup_and_load_model(progress)
 
203
  logger.info(f"Starting forecast for target time: {target_dt}")
204
 
205
  input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
 
206
  prediction_tensor = run_inference(input_tensor)
207
 
208
  img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
 
218
  raise gr.Error(f"Failed to generate forecast. Error: {e}")
219
 
220
  def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
221
+ if last_input_map is None: return None, None, None
 
222
  return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
223
 
224
  with gr.Blocks(theme=gr.themes.Soft()) as demo: