broadfield-dev commited on
Commit
3e08cbf
·
verified ·
1 Parent(s): 1dc86bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -16,7 +16,8 @@ from io import BytesIO
16
  import re
17
 
18
  from surya.models.helio_spectformer import HelioSpectFormer
19
- from surya.utils.data import build_scalers
 
20
  from surya.datasets.helio import inverse_transform_single_channel
21
 
22
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
@@ -48,7 +49,10 @@ def setup_and_load_model():
48
  config = yaml.safe_load(fp)
49
  APP_CACHE["config"] = config
50
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
51
- APP_CACHE["scalers"] = build_scalers(info=scalers_info)
 
 
 
52
 
53
  yield "Initializing model architecture..."
54
  model_config = config["model"]
@@ -79,7 +83,7 @@ def find_nearest_browse_image_url(channel, target_dt):
79
  url_code = CHANNEL_TO_URL_CODE[channel]
80
  base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
81
 
82
- for i in range(2): # Try today, then yesterday
83
  dt_to_try = target_dt - datetime.timedelta(days=i)
84
  dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
85
 
@@ -150,7 +154,7 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
150
  img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
151
  norm_data = np.array(img_resized, dtype=np.float32)
152
 
153
- scaled_data = scaler.transform(norm_data.reshape(-1, 1), c_idx=i).reshape(norm_data.shape)
154
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
155
  processed_tensors[t] = torch.stack(channel_tensors)
156
 
 
16
  import re
17
 
18
  from surya.models.helio_spectformer import HelioSpectFormer
19
+ # *** FIX: Import ScalerCollection and correct the import for inverse_transform_single_channel ***
20
+ from surya.utils.data import build_scalers, ScalerCollection
21
  from surya.datasets.helio import inverse_transform_single_channel
22
 
23
  warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
 
49
  config = yaml.safe_load(fp)
50
  APP_CACHE["config"] = config
51
  scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
52
+
53
+ # *** FIX: Create the ScalerCollection object from the dictionary returned by build_scalers ***
54
+ scalers_dict = build_scalers(info=scalers_info)
55
+ APP_CACHE["scalers"] = ScalerCollection(scalers_dict, config["data"]["sdo_channels"])
56
 
57
  yield "Initializing model architecture..."
58
  model_config = config["model"]
 
83
  url_code = CHANNEL_TO_URL_CODE[channel]
84
  base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
85
 
86
+ for i in range(2):
87
  dt_to_try = target_dt - datetime.timedelta(days=i)
88
  dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
89
 
 
154
  img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
155
  norm_data = np.array(img_resized, dtype=np.float32)
156
 
157
+ scaled_data = scaler.transform(norm_data, c_idx=i)
158
  channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
159
  processed_tensors[t] = torch.stack(channel_tensors)
160