Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,8 @@ from io import BytesIO
|
|
16 |
import re
|
17 |
|
18 |
from surya.models.helio_spectformer import HelioSpectFormer
|
19 |
-
|
|
|
20 |
from surya.datasets.helio import inverse_transform_single_channel
|
21 |
|
22 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
@@ -48,7 +49,10 @@ def setup_and_load_model():
|
|
48 |
config = yaml.safe_load(fp)
|
49 |
APP_CACHE["config"] = config
|
50 |
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
yield "Initializing model architecture..."
|
54 |
model_config = config["model"]
|
@@ -79,7 +83,7 @@ def find_nearest_browse_image_url(channel, target_dt):
|
|
79 |
url_code = CHANNEL_TO_URL_CODE[channel]
|
80 |
base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
|
81 |
|
82 |
-
for i in range(2):
|
83 |
dt_to_try = target_dt - datetime.timedelta(days=i)
|
84 |
dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
|
85 |
|
@@ -150,7 +154,7 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
150 |
img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
|
151 |
norm_data = np.array(img_resized, dtype=np.float32)
|
152 |
|
153 |
-
scaled_data = scaler.transform(norm_data
|
154 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
155 |
processed_tensors[t] = torch.stack(channel_tensors)
|
156 |
|
|
|
16 |
import re
|
17 |
|
18 |
from surya.models.helio_spectformer import HelioSpectFormer
|
19 |
+
# *** FIX: Import ScalerCollection and correct the import for inverse_transform_single_channel ***
|
20 |
+
from surya.utils.data import build_scalers, ScalerCollection
|
21 |
from surya.datasets.helio import inverse_transform_single_channel
|
22 |
|
23 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
|
|
49 |
config = yaml.safe_load(fp)
|
50 |
APP_CACHE["config"] = config
|
51 |
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
52 |
+
|
53 |
+
# *** FIX: Create the ScalerCollection object from the dictionary returned by build_scalers ***
|
54 |
+
scalers_dict = build_scalers(info=scalers_info)
|
55 |
+
APP_CACHE["scalers"] = ScalerCollection(scalers_dict, config["data"]["sdo_channels"])
|
56 |
|
57 |
yield "Initializing model architecture..."
|
58 |
model_config = config["model"]
|
|
|
83 |
url_code = CHANNEL_TO_URL_CODE[channel]
|
84 |
base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
|
85 |
|
86 |
+
for i in range(2):
|
87 |
dt_to_try = target_dt - datetime.timedelta(days=i)
|
88 |
dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
|
89 |
|
|
|
154 |
img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
|
155 |
norm_data = np.array(img_resized, dtype=np.float32)
|
156 |
|
157 |
+
scaled_data = scaler.transform(norm_data, c_idx=i)
|
158 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
159 |
processed_tensors[t] = torch.stack(channel_tensors)
|
160 |
|