broadfield-dev committed
Commit 5d92716 · verified · 1 Parent(s): f43722b

Update app.py

Files changed (1):
  app.py  +255 -111
app.py CHANGED
@@ -1,172 +1,316 @@
  import gradio as gr
  import torch
- from huggingface_hub import hf_hub_download
- from surya.model import Surya  # This now works because of the file structure
  import numpy as np
  from PIL import Image
  import os
  import warnings

- # Suppress warnings for a cleaner demo experience
  warnings.filterwarnings("ignore")

- # --- 1. Define Constants and Data Channels ---
- # Based on the Surya project's data preprocessing
- AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
- HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
- ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]

- # --- 2. Caching and Loading the Model and Data ---
  @gr.cache
- def load_model_and_data():
      """
-     Downloads the pre-trained Surya model, the test data, and initializes the model.
-     This function is cached so this happens only once.
      """
-     print("Downloading model and test data... This may take a moment.")
-     # Define local directories for caching
-     model_dir = "./surya_model"
-     data_dir = "./surya_data"
-     os.makedirs(model_dir, exist_ok=True)
-     os.makedirs(data_dir, exist_ok=True)
-
-     # Download the model weights and test data from Hugging Face
-     checkpoint_path = hf_hub_download(
          repo_id="nasa-ibm-ai4science/Surya-1.0",
-         filename="surya.366m.v1.pt",
-         local_dir=model_dir
      )
-     test_data_path = hf_hub_download(
-         repo_id="nasa-ibm-ai4science/Surya-1.0",
-         filename="test_data.pt",
-         local_dir=data_dir
-     )
-     print("Downloads complete.")
-
-     # Initialize the model architecture
-     model = Surya(
-         img_size=4096,
-         patch_size=16,
-         in_chans=13,
-         embed_dim=1280,
-         spectral_blocks=2,
-         attention_blocks=8,
      )

-     # Load the weights into the model
-     print("Loading model weights...")
-     model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
-     model.eval()
-     print("Model loaded successfully.")

-     # Load the test data
-     test_data = torch.load(test_data_path)
-     test_input = test_data["input"]  # Input tensor for the model
-     test_label = test_data["label"]  # Ground truth for comparison

-     return model, test_input, test_label

- # --- 3. Helper function for Image Conversion ---
- def tensor_to_image(tensor_slice):
      """
-     Normalizes a 2D tensor slice and converts it to a PIL Image for display.
      """
-     # Detach tensor from graph, move to CPU, and convert to numpy
-     img_np = tensor_slice.detach().cpu().numpy()

-     # Normalize the tensor to a 0-255 range for image display
-     min_val, max_val = np.min(img_np), np.max(img_np)
-     if max_val > min_val:
-         img_np = (img_np - min_val) / (max_val - min_val)

-     img_array = (img_np * 255).astype(np.uint8)
-     return Image.fromarray(img_array)
-
- # --- 4. Main Prediction and Visualization Function ---
- def run_forecast(channel_name, progress=gr.Progress()):
      """
-     This function is triggered by the button click in the Gradio interface.
-     It runs the model prediction and generates the images for display.
      """
-     progress(0, desc="Loading model and data (first run may be slow)...")
-     # Load the model and data (will be fast after the first run due to caching)
-     model, test_input, test_label = load_model_and_data()

-     progress(0.5, desc="Running inference on the model...")
-     # Perform the forecast
-     with torch.no_grad():
-         prediction = model(test_input)

-     progress(0.8, desc="Generating visualizations...")
-     # Get the index of the selected channel
-     channel_index = ALL_CHANNELS.index(channel_name)

-     # Extract the last time step from the input sequence for display
-     # Shape: [batch, channels, time, height, width] -> select channel, last time step
-     input_slice = test_input[0, channel_index, -1, :, :]
-     input_image = tensor_to_image(input_slice)

-     # Extract the corresponding slice from the model's prediction
-     # Shape: [batch, channels, time, height, width] -> select channel, first predicted step
-     predicted_slice = prediction[0, channel_index, 0, :, :]
-     predicted_image = tensor_to_image(predicted_slice)

-     # Extract the corresponding slice from the ground truth label
-     label_slice = test_label[0, channel_index, 0, :, :]
-     label_image = tensor_to_image(label_slice)

-     print(f"Forecast generated for channel: {channel_name}")
-     return input_image, predicted_image, label_image

- # --- 5. Building the Gradio Interface ---
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
          <div align="center">
-         # ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️
-         This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,
-         allowing a direct comparison between the model's prediction and the real ground truth.
          </div>
          """
      )
-
      with gr.Row():
          channel_selector = gr.Dropdown(
-             choices=ALL_CHANNELS,
-             value=ALL_CHANNELS[2],  # Default to "AIA 171 Å"
              label="🛰️ Select SDO Instrument Channel",
              info="Choose which solar observation channel to visualize."
          )

-     run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")
-
      with gr.Row():
          with gr.Column():
              gr.Markdown("### ⬅️ Final Input Image")
-             gr.Markdown("The last image shown to the model before it makes a prediction.")
-             input_display = gr.Image(label="Input Observation", height=400, width=400)
          with gr.Column():
              gr.Markdown("### 🔮 Model's Forecast")
-             gr.Markdown("What the Surya model predicted the Sun would look like.")
-             prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
          with gr.Column():
              gr.Markdown("### ✅ Ground Truth")
-             gr.Markdown("What the Sun *actually* looked like at the forecast time.")
-             label_display = gr.Image(label="Actual Observation", height=400, width=400)
-
-     gr.Markdown(
-         "--- \n"
-         "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
-         "The images are downscaled for display in this demo. "
-         "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
-     )

      run_button.click(
-         fn=run_forecast,
          inputs=[channel_selector],
-         outputs=[input_display, prediction_display, label_display]
      )

  if __name__ == "__main__":
      demo.launch(debug=True)
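
Editor's note: both the old and the new version decorate their one-time setup with `@gr.cache`, which does not appear to be part of Gradio's documented API. A minimal, self-contained sketch of the intended run-once behavior using the standard library instead (placeholder objects, not the app's real loader):

```python
# Illustrative only: run-once caching via functools.lru_cache, a stdlib
# stand-in for the @gr.cache decorator used in the diff.
from functools import lru_cache

@lru_cache(maxsize=1)
def load_model_and_data():
    print("expensive setup runs only once")
    model, test_input, test_label = object(), object(), object()  # placeholders
    return model, test_input, test_label

load_model_and_data()  # prints, then caches the returned tuple
load_model_and_data()  # served from cache; nothing is re-downloaded
```

The rewritten app.py (the `+` half of the diff) follows.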
 
  import gradio as gr
  import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from huggingface_hub import snapshot_download
+ import yaml
  import numpy as np
  from PIL import Image
+ import matplotlib.pyplot as plt  # (fix) required by plt.get_cmap() below; missing from the commit
+ import sunpy.visualization.colormaps as sunpy_cm
  import os
+ import glob
  import warnings
+ import logging

+ # --- Suppress verbose logging and warnings for a cleaner UI ---
  warnings.filterwarnings("ignore")
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # --- Dependencies from the Surya Repository ---
+ # NOTE: To keep this script self-contained, the required classes and functions
+ # from the 'surya' library are bundled in a local 'surya_dependencies' module
+ # and imported from there. In a full installation they would come from 'surya' itself.

+ from surya_dependencies import (
+     HelioSpectFormer,
+     HelioNetCDFDataset,
+     build_scalers,
+     custom_collate_fn,
+     inverse_transform_single_channel,
+     SDO_CHANNELS,
+     AIA_CHANNELS,
+     HMI_CHANNELS
+ )
+
+ # --- Global Cache for Model and Data ---
+ # We use a simple dictionary to act as a cache to avoid reloading.
+ APP_CACHE = {
+     "model": None,
+     "config": None,
+     "scalers": None,
+     "dataset": None,
+     "dataloader": None,
+     "device": "cuda" if torch.cuda.is_available() else "cpu",
+ }
+
+ # --- 1. Setup and Data Download ---
  @gr.cache
+ def setup_environment_and_download_data():
      """
+     Downloads all necessary files from Hugging Face: model, config, scalers, and validation data.
+     Also creates the necessary index file for the dataset loader.
+     This function is cached by Gradio to run only once.
      """
+     logger.info("Setting up environment. This will run only once.")
+     local_dir = "data/Surya-1.0"
+     # Download model, config, and scalers
+     snapshot_download(
          repo_id="nasa-ibm-ai4science/Surya-1.0",
+         local_dir=local_dir,
+         allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
      )
+
+     # Download validation data
+     data_dir = "data/Surya-1.0_validation_data"
+     snapshot_download(
+         repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
+         repo_type="dataset",
+         local_dir=data_dir,
+         allow_patterns="20140107_1[5-9]??.nc",
      )

+     # The test script requires an index file. We'll create it dynamically.
+     index_dir = "data/test_indices"
+     os.makedirs(index_dir, exist_ok=True)
+     index_file_path = os.path.join(index_dir, "test_surya_index.csv")
+
+     with open(index_file_path, "w") as f:
+         f.write("path\n")
+         # Find the downloaded NetCDF files and write their paths to the index
+         search_path = os.path.join(data_dir, "**", "*.nc")
+         for nc_file in sorted(glob.glob(search_path, recursive=True)):
+             f.write(f"{nc_file}\n")
+     logger.info(f"Created index file at {index_file_path}")
+     return index_file_path, local_dir
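
Aside: the setup above leans on `snapshot_download`'s `allow_patterns` argument, which fetches only the matching files instead of the whole repository. A minimal sketch of that pattern, assuming you want just the two small YAML files and not the large weights file (the repo id is the real one from the diff):

```python
# Sketch: selective download from the Surya repo; config files only,
# the model weights are deliberately excluded by the pattern list.
from huggingface_hub import snapshot_download

local = snapshot_download(
    repo_id="nasa-ibm-ai4science/Surya-1.0",
    local_dir="data/Surya-1.0",
    allow_patterns=["config.yaml", "scalers.yaml"],  # weights excluded
)
print("files under:", local)
```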
+
+ # --- 2. Model and Data Loading ---
+ def load_essentials(model_dir):
+     """Loads config, scalers, and the model into the APP_CACHE."""
+     if APP_CACHE["model"] is None:
+         logger.info("Loading config, scalers, and model for the first time...")
+         # Load config
+         with open(os.path.join(model_dir, "config.yaml")) as fp:
+             config = yaml.safe_load(fp)
+         APP_CACHE["config"] = config

+         # Build scalers for data normalization
+         scalers_info = yaml.safe_load(open(os.path.join(model_dir, "scalers.yaml"), "r"))
+         APP_CACHE["scalers"] = build_scalers(info=scalers_info)

+         # Initialize model from config
+         model = HelioSpectFormer(
+             img_size=config["model"]["img_size"],
+             patch_size=config["model"]["patch_size"],
+             in_chans=len(config["data"]["sdo_channels"]),
+             embed_dim=config["model"]["embed_dim"],
+             time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
+             depth=config["model"]["depth"],
+             n_spectral_blocks=config["model"]["n_spectral_blocks"],
+             num_heads=config["model"]["num_heads"],
+             mlp_ratio=config["model"]["mlp_ratio"],
+             drop_rate=config["model"]["drop_rate"],
+             dtype=torch.bfloat16,
+             window_size=config["model"]["window_size"],
+             dp_rank=config["model"]["dp_rank"],
+             learned_flow=config["model"]["learned_flow"],
+             use_latitude_in_learned_flow=config["model"]["learned_flow"],
+             init_weights=False,
+             checkpoint_layers=[i for i in range(config["model"]["depth"])],
+             rpe=config["model"]["rpe"],
+             ensemble=config["model"]["ensemble"],
+             finetune=config["model"]["finetune"],
+         )
+
+         # Load pre-trained weights
+         path_weights = os.path.join(model_dir, "surya.366m.v1.pt")
+         weights = torch.load(path_weights, map_location=torch.device(APP_CACHE["device"]))
+         model.load_state_dict(weights, strict=True)
+         model.to(APP_CACHE["device"])
+         model.eval()
+
+         n_params = sum(p.numel() for p in model.parameters()) / 1e6
+         logger.info(f"Surya FM: {n_params:.2f}M parameters loaded to {APP_CACHE['device']}.")
+         APP_CACHE["model"] = model

+ def get_dataloader(index_path):
+     """Initializes and returns a DataLoader for the validation data."""
+     if APP_CACHE["dataloader"] is None:
+         logger.info("Initializing dataset and dataloader...")
+         config = APP_CACHE["config"]
+         dataset = HelioNetCDFDataset(
+             index_path=index_path,
+             time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
+             time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
+             n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
+             rollout_steps=1,
+             channels=config["data"]["sdo_channels"],
+             scalers=APP_CACHE["scalers"],
+             phase="valid",  # Important: ensure no random augmentations
+         )
+         dataloader = DataLoader(
+             dataset, shuffle=False, batch_size=1, num_workers=2,
+             pin_memory=True, drop_last=False, collate_fn=custom_collate_fn,
+         )
+         APP_CACHE["dataloader"] = dataloader
+         APP_CACHE["dataset"] = dataset  # Also cache dataset for transformation info
+     return APP_CACHE["dataloader"]
+
+
+ # --- 3. Core Inference and Visualization Logic ---
+ def run_model_inference():
      """
+     Performs a single prediction step using the loaded model and dataloader.
+     Returns the raw input, prediction, and ground truth tensors.
      """
+     model = APP_CACHE["model"]
+     dataloader = APP_CACHE["dataloader"]
+     device = APP_CACHE["device"]
+
+     # Get the first (and only) batch of data from the validation set
+     batch_data, batch_metadata = next(iter(dataloader))

+     logger.info("Running inference on the validation batch...")
+     with torch.no_grad():
+         # Prepare input batch for the model
+         input_batch = {key: batch_data[key].to(device) for key in ["ts", "time_delta_input"]}
+         # Run model prediction
+         with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
+             prediction_tensor = model(input_batch)

+     # Get the input and target tensors for comparison
+     input_tensor = input_batch["ts"].to(dtype=torch.float32).cpu()
+     target_tensor = batch_data["forecast"].cpu()
+     prediction_tensor = prediction_tensor.to(dtype=torch.float32).cpu()
+
+     logger.info("Inference complete.")
+     return input_tensor, prediction_tensor, target_tensor
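
Aside: the forward pass above runs under `torch.autocast` with bfloat16 and the outputs are cast back to float32 on the CPU before visualization. A self-contained sketch of that pattern, with a toy convolution standing in for the Surya model:

```python
# Sketch of the mixed-precision inference pattern (toy model, not Surya).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
toy_model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device).eval()
x = torch.randn(1, 3, 64, 64, device=device)

with torch.no_grad():
    # device_type must be "cuda" or "cpu", hence the split on ":".
    with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
        y = toy_model(x)

y = y.to(dtype=torch.float32).cpu()  # float32 copy, numpy-friendly for plotting
print(y.dtype, y.shape)
```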

+ def create_visualizations(channel_name, input_tensor, prediction_tensor, target_tensor):
      """
+     Takes raw tensors and a channel name, applies inverse transformation,
+     and converts them to displayable PIL Images.
      """
+     if input_tensor is None:
+         return None, None, None, "Please run the forecast first."
+
+     logger.info(f"Creating visualization for channel: {channel_name}")
+     c_idx = SDO_CHANNELS.index(channel_name)
+     dataset = APP_CACHE["dataset"]
+     means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()

+     # --- Denormalize data for visualization ---
+     # Final input image given to the model (last in sequence)
+     input_slice = inverse_transform_single_channel(
+         input_tensor[0, c_idx, -1, :, :].numpy(),
+         mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+     )
+     # Model's prediction
+     pred_slice = inverse_transform_single_channel(
+         prediction_tensor[0, c_idx, :, :].numpy(),
+         mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+     )
+     # Ground truth image
+     target_slice = inverse_transform_single_channel(
+         target_tensor[0, c_idx, 0, :, :].numpy(),
+         mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
+     )
+
+     # --- Convert to images ---
+     # Use a shared color scale for better comparison, clipped at the 99.5th percentile
+     vmax = np.quantile(np.concatenate([input_slice, pred_slice, target_slice]), 0.995)
+
+     # Determine colormap from channel name
+     cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
+     cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

+     def to_pil(data, vmin=0, vmax=vmax, cmap=cmap):
+         data_clipped = np.clip(data, vmin, vmax)
+         data_norm = (data_clipped - vmin) / (vmax - vmin)
+         return Image.fromarray((cmap(data_norm)[:, :, :3] * 255).astype(np.uint8))

+     return to_pil(input_slice), to_pil(pred_slice), to_pil(target_slice), f"Displaying forecast for {channel_name}"
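
Aside: `to_pil` is the display pipeline in miniature: clip to a shared `vmax`, normalize to [0, 1], push through an SDO colormap, drop the alpha channel, and scale to 8-bit. A standalone version of that pipeline on random data (`sdoaia171` is a real sunpy colormap name; the array here is a stand-in for a solar image):

```python
# Sketch of the clip -> normalize -> colormap -> PIL pipeline used above.
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm  # registers SDO colormaps

data = np.random.rand(256, 256).astype(np.float32)  # stand-in solar image
vmin, vmax = 0.0, float(np.quantile(data, 0.995))
norm = (np.clip(data, vmin, vmax) - vmin) / (vmax - vmin)
cmap = plt.get_cmap(sunpy_cm.cmlist.get("sdoaia171", "gray"))
img = Image.fromarray((cmap(norm)[:, :, :3] * 255).astype(np.uint8))
img.save("aia171_preview.png")
```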
 
 
 

+ # --- 4. Gradio Controller Functions ---
+ def forecast_controller(channel_name, progress=gr.Progress()):
+     """
+     Main function triggered by the 'Generate' button. Orchestrates the entire pipeline.
+     """
+     progress(0, desc="Downloading model and data (first launch only)...")
+     index_path, model_dir = setup_environment_and_download_data()
+
+     progress(0.4, desc="Loading model and building data pipeline...")
+     load_essentials(model_dir)
+     get_dataloader(index_path)
+
+     progress(0.7, desc=f"Running inference on {APP_CACHE['device']}...")
+     input_t, pred_t, target_t = run_model_inference()

+     progress(0.9, desc="Creating visualizations...")
+     img_in, img_pred, img_target, status = create_visualizations(channel_name, input_t, pred_t, target_t)

+     return img_in, img_pred, img_target, status, input_t, pred_t, target_t
+

+ # --- 5. Gradio UI Layout ---
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # Hidden state variables to store the raw tensors after inference
+     state_input = gr.State()
+     state_prediction = gr.State()
+     state_target = gr.State()
+
      gr.Markdown(
          """
          <div align="center">
+         # ☀️ Surya: Live Model Demo ☀️
+         ### An Interactive Interface for NASA's Heliophysics Foundation Model
+         This demo runs the **actual** Surya model on its official validation data for **2014-01-07**.
+         Click the button to generate a forecast, then use the dropdown to explore the results across different SDO instrument channels.
          </div>
          """
      )
+
      with gr.Row():
          channel_selector = gr.Dropdown(
+             choices=SDO_CHANNELS,
+             value="aia171",
              label="🛰️ Select SDO Instrument Channel",
              info="Choose which solar observation channel to visualize."
          )
+         run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary", scale=2)

+     status_box = gr.Textbox(label="Status", interactive=False, value="Ready. Press 'Generate Forecast' to start.")
+
      with gr.Row():
          with gr.Column():
              gr.Markdown("### ⬅️ Final Input Image")
+             gr.Markdown("The last observation shown to the model (T-1).")
+             input_display = gr.Image(label="Input", height=512, width=512, interactive=False)
          with gr.Column():
              gr.Markdown("### 🔮 Model's Forecast")
+             gr.Markdown("Surya's prediction for the next timestep (T+0).")
+             prediction_display = gr.Image(label="Prediction", height=512, width=512, interactive=False)
          with gr.Column():
              gr.Markdown("### ✅ Ground Truth")
+             gr.Markdown("What the Sun *actually* looked like at T+0.")
+             label_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

+     # --- Event Handlers ---
      run_button.click(
+         fn=forecast_controller,
          inputs=[channel_selector],
+         outputs=[input_display, prediction_display, label_display, status_box, state_input, state_prediction, state_target]
+     )
+
+     channel_selector.change(
+         fn=create_visualizations,
+         inputs=[channel_selector, state_input, state_prediction, state_target],
+         outputs=[input_display, prediction_display, label_display, status_box]
      )

  if __name__ == "__main__":
+     # The 'surya_dependencies.py' module must sit in the same directory as this script;
+     # fail fast with a clear error if it is missing.
+     if not os.path.exists("surya_dependencies.py"):
+         raise FileNotFoundError("The required 'surya_dependencies.py' file is missing. Please download it from the provided source.")
      demo.launch(debug=True)
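
Design note on the rewrite: inference results are parked in `gr.State`, so changing the channel dropdown re-renders from the cached tensors instead of re-running the model. A minimal runnable sketch of that round-trip (toy functions, not the app's):

```python
# Sketch: the expensive step writes to gr.State; a cheap re-render reads it.
import gradio as gr

def expensive(x):
    data = x * 2  # stand-in for model inference
    return f"computed {data}", data  # second output lands in the State

def re_render(choice, data):
    return f"view '{choice}' of cached result {data}"  # no recompute

with gr.Blocks() as demo:
    cache = gr.State()
    num = gr.Number(value=21)
    choice = gr.Dropdown(choices=["a", "b"], value="a")
    out = gr.Textbox()
    run = gr.Button("Run")
    run.click(expensive, inputs=[num], outputs=[out, cache])
    choice.change(re_render, inputs=[choice, cache], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```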