Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,6 @@ import sunpy.visualization.colormaps as sunpy_cm
|
|
22 |
# --- Use the official Surya modules ---
|
23 |
from surya.models.helio_spectformer import HelioSpectFormer
|
24 |
from surya.utils.data import build_scalers
|
25 |
-
# *** FIX: Corrected import location for the inverse transform function ***
|
26 |
from surya.datasets.helio import inverse_transform_single_channel
|
27 |
|
28 |
# --- Configuration ---
|
@@ -33,15 +32,17 @@ logger = logging.getLogger(__name__)
|
|
33 |
|
34 |
# Global cache for model, config, etc.
|
35 |
APP_CACHE = {}
|
|
|
|
|
36 |
SDO_CHANNELS_MAP = {
|
37 |
-
"aia94": (a.Wavelength(94
|
38 |
-
"aia131": (a.Wavelength(131
|
39 |
-
"aia171": (a.Wavelength(171
|
40 |
-
"aia193": (a.Wavelength(193
|
41 |
-
"aia211": (a.Wavelength(211
|
42 |
-
"aia304": (a.Wavelength(304
|
43 |
-
"aia335": (a.Wavelength(335
|
44 |
-
"aia1600": (a.Wavelength(1600
|
45 |
"hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
|
46 |
"hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
47 |
"hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
|
@@ -101,11 +102,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
101 |
all_times = sorted(list(set(input_times + [target_time])))
|
102 |
|
103 |
data_maps = {}
|
104 |
-
total_downloads = len(all_times) * len(
|
105 |
downloads_done = 0
|
106 |
for t in all_times:
|
107 |
data_maps[t] = {}
|
108 |
for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
|
|
|
109 |
progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
|
110 |
|
111 |
if channel in ["hmi_by", "hmi_bz"]:
|
@@ -113,19 +115,13 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
113 |
continue
|
114 |
|
115 |
time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
|
116 |
-
|
117 |
-
|
118 |
-
search_query_list.append(a.Instrument.aia)
|
119 |
-
else:
|
120 |
-
search_query_list.append(a.Instrument.hmi)
|
121 |
-
|
122 |
-
query = Fido.search(*search_query_list)
|
123 |
|
124 |
if not query: raise ValueError(f"No data found for {channel} at {t}")
|
125 |
files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
|
126 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
127 |
-
|
128 |
-
|
129 |
output_wcs = WCS(naxis=2)
|
130 |
output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
|
131 |
output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
|
@@ -133,14 +129,10 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
133 |
output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
|
134 |
|
135 |
processed_tensors = {}
|
136 |
-
total_processing = len(all_times) * len(SDO_CHANNELS)
|
137 |
-
processing_done = 0
|
138 |
for t, channel_maps in data_maps.items():
|
139 |
channel_tensors = []
|
140 |
for i, channel in enumerate(SDO_CHANNELS):
|
141 |
-
progress(processing_done / total_processing, desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
|
142 |
smap = channel_maps[channel]
|
143 |
-
|
144 |
reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
|
145 |
|
146 |
exp_time = smap.meta.get('exptime', 1.0)
|
@@ -150,7 +142,6 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
150 |
scaler = APP_CACHE["scalers"][channel]
|
151 |
scaled_data = scaler.transform(norm_data)
|
152 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
153 |
-
processing_done += 1
|
154 |
processed_tensors[t] = torch.stack(channel_tensors)
|
155 |
|
156 |
input_tensor_list = [processed_tensors[t] for t in input_times]
|
@@ -178,15 +169,12 @@ def run_inference(input_tensor):
|
|
178 |
return prediction.cpu()
|
179 |
|
180 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
181 |
-
if last_input_map is None:
|
182 |
-
return None, None, None
|
183 |
|
184 |
c_idx = SDO_CHANNELS.index(channel_name)
|
185 |
-
|
186 |
-
means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][SDO_CHANNELS[c_idx]].get_params()
|
187 |
pred_slice = inverse_transform_single_channel(
|
188 |
-
prediction_tensor[0, c_idx].numpy(),
|
189 |
-
mean=means, std=stds, epsilon=epsilons, sl_scale_factor=sl_scale_factors
|
190 |
)
|
191 |
|
192 |
vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
|
@@ -206,8 +194,7 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
|
|
206 |
# --- 4. Gradio UI and Controllers ---
|
207 |
def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
208 |
try:
|
209 |
-
if not dt_str:
|
210 |
-
raise gr.Error("Please select a date and time.")
|
211 |
|
212 |
progress(0, desc="Initializing...")
|
213 |
setup_and_load_model(progress)
|
@@ -216,7 +203,6 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
|
216 |
logger.info(f"Starting forecast for target time: {target_dt}")
|
217 |
|
218 |
input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
|
219 |
-
|
220 |
prediction_tensor = run_inference(input_tensor)
|
221 |
|
222 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
@@ -232,8 +218,7 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
|
232 |
raise gr.Error(f"Failed to generate forecast. Error: {e}")
|
233 |
|
234 |
def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
|
235 |
-
if last_input_map is None:
|
236 |
-
return None, None, None
|
237 |
return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
|
238 |
|
239 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
|
22 |
# --- Use the official Surya modules ---
|
23 |
from surya.models.helio_spectformer import HelioSpectFormer
|
24 |
from surya.utils.data import build_scalers
|
|
|
25 |
from surya.datasets.helio import inverse_transform_single_channel
|
26 |
|
27 |
# --- Configuration ---
|
|
|
32 |
|
33 |
# Global cache for model, config, etc.
|
34 |
APP_CACHE = {}
|
35 |
+
|
36 |
+
# *** FIX: Corrected the a.Wavelength calls to use astropy units ***
|
37 |
SDO_CHANNELS_MAP = {
|
38 |
+
"aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
|
39 |
+
"aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
|
40 |
+
"aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
|
41 |
+
"aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
|
42 |
+
"aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
|
43 |
+
"aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
|
44 |
+
"aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
|
45 |
+
"aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
|
46 |
"hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
|
47 |
"hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
|
48 |
"hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)), # Placeholder
|
|
|
102 |
all_times = sorted(list(set(input_times + [target_time])))
|
103 |
|
104 |
data_maps = {}
|
105 |
+
total_downloads = len(all_times) * len(SDO_CHANNELS)
|
106 |
downloads_done = 0
|
107 |
for t in all_times:
|
108 |
data_maps[t] = {}
|
109 |
for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
|
110 |
+
downloads_done += 1
|
111 |
progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
|
112 |
|
113 |
if channel in ["hmi_by", "hmi_bz"]:
|
|
|
115 |
continue
|
116 |
|
117 |
time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
|
118 |
+
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
119 |
+
query = Fido.search(time_attr, instrument, physobs, sample)
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
if not query: raise ValueError(f"No data found for {channel} at {t}")
|
122 |
files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
|
123 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
124 |
+
|
|
|
125 |
output_wcs = WCS(naxis=2)
|
126 |
output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
|
127 |
output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
|
|
|
129 |
output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
|
130 |
|
131 |
processed_tensors = {}
|
|
|
|
|
132 |
for t, channel_maps in data_maps.items():
|
133 |
channel_tensors = []
|
134 |
for i, channel in enumerate(SDO_CHANNELS):
|
|
|
135 |
smap = channel_maps[channel]
|
|
|
136 |
reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
|
137 |
|
138 |
exp_time = smap.meta.get('exptime', 1.0)
|
|
|
142 |
scaler = APP_CACHE["scalers"][channel]
|
143 |
scaled_data = scaler.transform(norm_data)
|
144 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
|
|
145 |
processed_tensors[t] = torch.stack(channel_tensors)
|
146 |
|
147 |
input_tensor_list = [processed_tensors[t] for t in input_times]
|
|
|
169 |
return prediction.cpu()
|
170 |
|
171 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
172 |
+
if last_input_map is None: return None, None, None
|
|
|
173 |
|
174 |
c_idx = SDO_CHANNELS.index(channel_name)
|
175 |
+
means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][channel_name].get_params()
|
|
|
176 |
pred_slice = inverse_transform_single_channel(
|
177 |
+
prediction_tensor[0, c_idx].numpy(), mean=means, std=stds, epsilon=epsilons, sl_scale_factor=sl_scale_factors
|
|
|
178 |
)
|
179 |
|
180 |
vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
|
|
|
194 |
# --- 4. Gradio UI and Controllers ---
|
195 |
def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
196 |
try:
|
197 |
+
if not dt_str: raise gr.Error("Please select a date and time.")
|
|
|
198 |
|
199 |
progress(0, desc="Initializing...")
|
200 |
setup_and_load_model(progress)
|
|
|
203 |
logger.info(f"Starting forecast for target time: {target_dt}")
|
204 |
|
205 |
input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
|
|
|
206 |
prediction_tensor = run_inference(input_tensor)
|
207 |
|
208 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
|
|
218 |
raise gr.Error(f"Failed to generate forecast. Error: {e}")
|
219 |
|
220 |
def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
|
221 |
+
if last_input_map is None: return None, None, None
|
|
|
222 |
return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
|
223 |
|
224 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|