broadfield-dev commited on
Commit
03610d5
·
verified ·
1 Parent(s): 3e08cbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -16,8 +16,7 @@ from io import BytesIO
16
  import re
17
 
18
  from surya.models.helio_spectformer import HelioSpectFormer
19
- # *** FIX: Import ScalerCollection and correct the import for inverse_transform_single_channel ***
20
- from surya.utils.data import build_scalers, ScalerCollection
21
  from surya.datasets.helio import inverse_transform_single_channel
22
 
23
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
@@ -49,10 +48,7 @@ def setup_and_load_model():
49
  config = yaml.safe_load(fp)
50
  APP_CACHE["config"] = config
51
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
52
-
53
- # *** FIX: Create the ScalerCollection object from the dictionary returned by build_scalers ***
54
- scalers_dict = build_scalers(info=scalers_info)
55
- APP_CACHE["scalers"] = ScalerCollection(scalers_dict, config["data"]["sdo_channels"])
56
 
57
  yield "Initializing model architecture..."
58
  model_config = config["model"]
@@ -142,7 +138,7 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
142
  images[t][channel] = Image.open(BytesIO(response.content))
143
 
144
  yield "✅ All images found and downloaded. Starting preprocessing..."
145
- scaler = APP_CACHE["scalers"]
146
  processed_tensors = {}
147
  for t, channel_images in images.items():
148
  channel_tensors = []
@@ -154,7 +150,9 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
154
  img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
155
  norm_data = np.array(img_resized, dtype=np.float32)
156
 
157
- scaled_data = scaler.transform(norm_data, c_idx=i)
 
 
158
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
159
  processed_tensors[t] = torch.stack(channel_tensors)
160
 
@@ -180,13 +178,17 @@ def run_inference(input_tensor):
180
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
181
  if last_input_map is None: return None, None, None
182
  c_idx = SDO_CHANNELS.index(channel_name)
183
- scaler = APP_CACHE["scalers"]
184
- all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
185
- mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
186
- pred_slice = inverse_transform_single_channel(
187
- prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
188
- )
189
 
 
 
 
 
 
 
 
 
 
 
190
  target_img_data = np.array(target_map[channel_name])
191
  vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
192
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
 
16
  import re
17
 
18
  from surya.models.helio_spectformer import HelioSpectFormer
19
+ from surya.utils.data import build_scalers
 
20
  from surya.datasets.helio import inverse_transform_single_channel
21
 
22
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
 
48
  config = yaml.safe_load(fp)
49
  APP_CACHE["config"] = config
50
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
51
+ APP_CACHE["scalers"] = build_scalers(info=scalers_info)
 
 
 
52
 
53
  yield "Initializing model architecture..."
54
  model_config = config["model"]
 
138
  images[t][channel] = Image.open(BytesIO(response.content))
139
 
140
  yield "✅ All images found and downloaded. Starting preprocessing..."
141
+ scalers_dict = APP_CACHE["scalers"]
142
  processed_tensors = {}
143
  for t, channel_images in images.items():
144
  channel_tensors = []
 
150
  img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
151
  norm_data = np.array(img_resized, dtype=np.float32)
152
 
153
+ # *** FIX: Retrieve the correct scaler object from the dictionary for the current channel ***
154
+ scaler = scalers_dict[channel]
155
+ scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
156
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
157
  processed_tensors[t] = torch.stack(channel_tensors)
158
 
 
178
  def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
179
  if last_input_map is None: return None, None, None
180
  c_idx = SDO_CHANNELS.index(channel_name)
 
 
 
 
 
 
181
 
182
+ # *** FIX: Retrieve the correct scaler object for the current channel to get its parameters ***
183
+ scaler = APP_CACHE["scalers"][channel_name]
184
+ params = scaler.to_dict()
185
+ mean, std = params['mean'], params['std']
186
+
187
+ # Note: The inverse transform for the simplified JPEG pipeline might differ from the original
188
+ # We will use a standard inverse scaling, which is the most logical approach here.
189
+ pred_slice_scaled = prediction_tensor[0, c_idx].numpy()
190
+ pred_slice = (pred_slice_scaled * std) + mean
191
+
192
  target_img_data = np.array(target_map[channel_name])
193
  vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
194
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'