Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,8 +16,7 @@ from io import BytesIO
|
|
16 |
import re
|
17 |
|
18 |
from surya.models.helio_spectformer import HelioSpectFormer
|
19 |
-
|
20 |
-
from surya.utils.data import build_scalers, ScalerCollection
|
21 |
from surya.datasets.helio import inverse_transform_single_channel
|
22 |
|
23 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
@@ -49,10 +48,7 @@ def setup_and_load_model():
|
|
49 |
config = yaml.safe_load(fp)
|
50 |
APP_CACHE["config"] = config
|
51 |
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
52 |
-
|
53 |
-
# *** FIX: Create the ScalerCollection object from the dictionary returned by build_scalers ***
|
54 |
-
scalers_dict = build_scalers(info=scalers_info)
|
55 |
-
APP_CACHE["scalers"] = ScalerCollection(scalers_dict, config["data"]["sdo_channels"])
|
56 |
|
57 |
yield "Initializing model architecture..."
|
58 |
model_config = config["model"]
|
@@ -142,7 +138,7 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
142 |
images[t][channel] = Image.open(BytesIO(response.content))
|
143 |
|
144 |
yield "✅ All images found and downloaded. Starting preprocessing..."
|
145 |
-
|
146 |
processed_tensors = {}
|
147 |
for t, channel_images in images.items():
|
148 |
channel_tensors = []
|
@@ -154,7 +150,9 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
154 |
img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
|
155 |
norm_data = np.array(img_resized, dtype=np.float32)
|
156 |
|
157 |
-
|
|
|
|
|
158 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
159 |
processed_tensors[t] = torch.stack(channel_tensors)
|
160 |
|
@@ -180,13 +178,17 @@ def run_inference(input_tensor):
|
|
180 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
181 |
if last_input_map is None: return None, None, None
|
182 |
c_idx = SDO_CHANNELS.index(channel_name)
|
183 |
-
scaler = APP_CACHE["scalers"]
|
184 |
-
all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
|
185 |
-
mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
|
186 |
-
pred_slice = inverse_transform_single_channel(
|
187 |
-
prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
|
188 |
-
)
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
target_img_data = np.array(target_map[channel_name])
|
191 |
vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
|
192 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
|
|
16 |
import re
|
17 |
|
18 |
from surya.models.helio_spectformer import HelioSpectFormer
|
19 |
+
from surya.utils.data import build_scalers
|
|
|
20 |
from surya.datasets.helio import inverse_transform_single_channel
|
21 |
|
22 |
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
|
|
48 |
config = yaml.safe_load(fp)
|
49 |
APP_CACHE["config"] = config
|
50 |
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
51 |
+
APP_CACHE["scalers"] = build_scalers(info=scalers_info)
|
|
|
|
|
|
|
52 |
|
53 |
yield "Initializing model architecture..."
|
54 |
model_config = config["model"]
|
|
|
138 |
images[t][channel] = Image.open(BytesIO(response.content))
|
139 |
|
140 |
yield "✅ All images found and downloaded. Starting preprocessing..."
|
141 |
+
scalers_dict = APP_CACHE["scalers"]
|
142 |
processed_tensors = {}
|
143 |
for t, channel_images in images.items():
|
144 |
channel_tensors = []
|
|
|
150 |
img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
|
151 |
norm_data = np.array(img_resized, dtype=np.float32)
|
152 |
|
153 |
+
# *** FIX: Retrieve the correct scaler object from the dictionary for the current channel ***
|
154 |
+
scaler = scalers_dict[channel]
|
155 |
+
scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
|
156 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
157 |
processed_tensors[t] = torch.stack(channel_tensors)
|
158 |
|
|
|
178 |
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
|
179 |
if last_input_map is None: return None, None, None
|
180 |
c_idx = SDO_CHANNELS.index(channel_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
+
# *** FIX: Retrieve the correct scaler object for the current channel to get its parameters ***
|
183 |
+
scaler = APP_CACHE["scalers"][channel_name]
|
184 |
+
params = scaler.to_dict()
|
185 |
+
mean, std = params['mean'], params['std']
|
186 |
+
|
187 |
+
# Note: The inverse transform for the simplified JPEG pipeline might differ from the original
|
188 |
+
# We will use a standard inverse scaling, which is the most logical approach here.
|
189 |
+
pred_slice_scaled = prediction_tensor[0, c_idx].numpy()
|
190 |
+
pred_slice = (pred_slice_scaled * std) + mean
|
191 |
+
|
192 |
target_img_data = np.array(target_map[channel_name])
|
193 |
vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
|
194 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|