broadfield-dev commited on
Commit
f43722b
·
verified ·
1 Parent(s): f966038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -85
app.py CHANGED
@@ -1,26 +1,49 @@
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
- from surya.model import Surya
5
  import numpy as np
6
  from PIL import Image
 
7
  import warnings
8
 
9
  # Suppress warnings for a cleaner demo experience
10
  warnings.filterwarnings("ignore")
11
 
12
- # --- Model Loading ---
 
 
 
 
 
 
13
  @gr.cache
14
- def load_model():
15
  """
16
- Downloads the pre-trained Surya model weights and initializes the model.
17
- This function is cached so the model is only loaded once.
18
  """
 
 
 
 
 
 
 
 
19
  checkpoint_path = hf_hub_download(
20
  repo_id="nasa-ibm-ai4science/Surya-1.0",
21
- filename="surya.366m.v1.pt"
 
 
 
 
 
 
22
  )
 
23
 
 
24
  model = Surya(
25
  img_size=4096,
26
  patch_size=16,
@@ -30,106 +53,120 @@ def load_model():
30
  attention_blocks=8,
31
  )
32
 
 
 
33
  model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
34
  model.eval()
35
- return model
 
 
 
 
 
36
 
37
- model = load_model()
38
 
39
- # --- Core Prediction Logic ---
40
- def predict_solar_activity(time_steps, forecast_horizon):
41
  """
42
- Generates a forecast of solar activity using the Surya model.
43
- For this demo, we use a dummy input tensor to simulate the model's input.
44
- In a real-world scenario, this function would fetch and preprocess
45
- actual SDO data for the given time steps.
46
  """
47
- # Create a dummy input tensor representing a sequence of solar observations
48
- # Shape: [batch_size, channels, time_steps, height, width]
49
- dummy_input = torch.randn(1, 13, time_steps, 4096, 4096)
 
 
 
 
 
 
 
50
 
51
- # In a real application, you would replace the dummy input with actual,
52
- # preprocessed data from the Solar Dynamics Observatory (SDO).
53
- # Preprocessing would involve alignment and normalization as described
54
- # in the Surya paper.
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  with torch.no_grad():
57
- # The model's prediction would be based on the forecast_horizon
58
- # For this demo, we simulate a prediction by selecting a slice of the input
59
- prediction = model(dummy_input)
60
-
61
- # --- Visualization ---
62
- # For demonstration, we will visualize one of the output channels.
63
- # We will take the last predicted time step.
64
- predicted_image_tensor = prediction[0, 0, -1, :, :] # Visualizing the first channel
65
-
66
- # Normalize the tensor to a 0-255 range for image display
67
- normalized_tensor = (predicted_image_tensor - predicted_image_tensor.min()) / \
68
- (predicted_image_tensor.max() - predicted_image_tensor.min())
69
- image_array = (normalized_tensor * 255).byte().cpu().numpy()
70
- predicted_image = Image.fromarray(image_array)
71
-
72
- # For the flare prediction, we'll generate a dummy probability
73
- flare_probability = np.random.rand()
74
- if flare_probability > 0.5:
75
- flare_class = "M-class or X-class Flare"
76
- confidence = flare_probability
77
- else:
78
- flare_class = "No significant flare"
79
- confidence = 1 - flare_probability
80
-
81
- return predicted_image, {flare_class: confidence}
82
-
83
- # --- Gradio Interface Definition ---
84
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
85
  gr.Markdown(
86
  """
87
  <div align="center">
88
- # ☀️ Surya: Foundation Model for Heliophysics ☀️
89
- *A Gradio Demo for NASA's Solar Foundation Model*
 
90
  </div>
91
  """
92
  )
93
- gr.Markdown(
94
- "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
95
- "This demo showcases its capability to forecast solar dynamics."
96
- )
97
 
98
  with gr.Row():
99
- with gr.Column(scale=1):
100
- gr.Markdown("### ⚙️ Prediction Parameters")
101
- time_steps_slider = gr.Slider(
102
- minimum=1, maximum=10, value=5, step=1,
103
- label="Number of Input Time Steps (12-min cadence)",
104
- info="Represents the sequence of past solar observations to feed the model."
105
- )
106
- forecast_horizon_slider = gr.Slider(
107
- minimum=1, maximum=24, value=1,
108
- label="Forecast Horizon (hours)",
109
- info="How far into the future to predict."
110
- )
111
- predict_button = gr.Button("🔮 Generate Forecast", variant="primary")
112
-
113
- with gr.Column(scale=2):
114
- gr.Markdown("### 🛰️ Predicted Solar Image")
115
- output_image = gr.Image(label="Forecasted SDO Image (AIA 171 Å)", height=512, width=512)
116
- gr.Markdown("### 💥 Solar Flare Prediction")
117
- output_flare = gr.Label(label="Flare Probability")
118
-
119
- predict_button.click(
120
- fn=predict_solar_activity,
121
- inputs=[time_steps_slider, forecast_horizon_slider],
122
- outputs=[output_image, output_flare]
123
- )
124
 
125
- gr.Markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  gr.Markdown(
127
- "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
128
- "The core of this application is the loaded Surya model from NASA and IBM."
 
 
129
  )
130
- gr.Markdown(
131
- "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
 
 
 
132
  )
133
 
134
  if __name__ == "__main__":
135
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
+ from surya.model import Surya # This now works because of the file structure
5
  import numpy as np
6
  from PIL import Image
7
+ import os
8
  import warnings
9
 
10
  # Suppress warnings for a cleaner demo experience
11
  warnings.filterwarnings("ignore")
12
 
13
+ # --- 1. Define Constants and Data Channels ---
14
+ # Based on the Surya project's data preprocessing
15
+ AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
16
+ HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
17
+ ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]
18
+
19
+ # --- 2. Caching and Loading the Model and Data ---
20
  @gr.cache
21
+ def load_model_and_data():
22
  """
23
+ Downloads the pre-trained Surya model, the test data, and initializes the model.
24
+ This function is cached so this happens only once.
25
  """
26
+ print("Downloading model and test data... This may take a moment.")
27
+ # Define local directories for caching
28
+ model_dir = "./surya_model"
29
+ data_dir = "./surya_data"
30
+ os.makedirs(model_dir, exist_ok=True)
31
+ os.makedirs(data_dir, exist_ok=True)
32
+
33
+ # Download the model weights and test data from Hugging Face
34
  checkpoint_path = hf_hub_download(
35
  repo_id="nasa-ibm-ai4science/Surya-1.0",
36
+ filename="surya.366m.v1.pt",
37
+ local_dir=model_dir
38
+ )
39
+ test_data_path = hf_hub_download(
40
+ repo_id="nasa-ibm-ai4science/Surya-1.0",
41
+ filename="test_data.pt",
42
+ local_dir=data_dir
43
  )
44
+ print("Downloads complete.")
45
 
46
+ # Initialize the model architecture
47
  model = Surya(
48
  img_size=4096,
49
  patch_size=16,
 
53
  attention_blocks=8,
54
  )
55
 
56
+ # Load the weights into the model
57
+ print("Loading model weights...")
58
  model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
59
  model.eval()
60
+ print("Model loaded successfully.")
61
+
62
+ # Load the test data
63
+ test_data = torch.load(test_data_path)
64
+ test_input = test_data["input"] # Input tensor for the model
65
+ test_label = test_data["label"] # Ground truth for comparison
66
 
67
+ return model, test_input, test_label
68
 
69
+ # --- 3. Helper function for Image Conversion ---
70
+ def tensor_to_image(tensor_slice):
71
  """
72
+ Normalizes a 2D tensor slice and converts it to a PIL Image for display.
 
 
 
73
  """
74
+ # Detach tensor from graph, move to CPU, and convert to numpy
75
+ img_np = tensor_slice.detach().cpu().numpy()
76
+
77
+ # Normalize the tensor to a 0-255 range for image display
78
+ min_val, max_val = np.min(img_np), np.max(img_np)
79
+ if max_val > min_val:
80
+ img_np = (img_np - min_val) / (max_val - min_val)
81
+
82
+ img_array = (img_np * 255).astype(np.uint8)
83
+ return Image.fromarray(img_array)
84
 
 
 
 
 
85
 
86
+ # --- 4. Main Prediction and Visualization Function ---
87
+ def run_forecast(channel_name, progress=gr.Progress()):
88
+ """
89
+ This function is triggered by the button click in the Gradio interface.
90
+ It runs the model prediction and generates the images for display.
91
+ """
92
+ progress(0, desc="Loading model and data (first run may be slow)...")
93
+ # Load the model and data (will be fast after the first run due to caching)
94
+ model, test_input, test_label = load_model_and_data()
95
+
96
+ progress(0.5, desc="Running inference on the model...")
97
+ # Perform the forecast
98
  with torch.no_grad():
99
+ prediction = model(test_input)
100
+
101
+ progress(0.8, desc="Generating visualizations...")
102
+ # Get the index of the selected channel
103
+ channel_index = ALL_CHANNELS.index(channel_name)
104
+
105
+ # Extract the last time step from the input sequence for display
106
+ # Shape: [batch, channels, time, height, width] -> select channel, last time step
107
+ input_slice = test_input[0, channel_index, -1, :, :]
108
+ input_image = tensor_to_image(input_slice)
109
+
110
+ # Extract the corresponding slice from the model's prediction
111
+ # Shape: [batch, channels, time, height, width] -> select channel, first predicted step
112
+ predicted_slice = prediction[0, channel_index, 0, :, :]
113
+ predicted_image = tensor_to_image(predicted_slice)
114
+
115
+ # Extract the corresponding slice from the ground truth label
116
+ label_slice = test_label[0, channel_index, 0, :, :]
117
+ label_image = tensor_to_image(label_slice)
118
+
119
+ print(f"Forecast generated for channel: {channel_name}")
120
+ return input_image, predicted_image, label_image
121
+
122
+ # --- 5. Building the Gradio Interface ---
 
 
 
123
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
124
  gr.Markdown(
125
  """
126
  <div align="center">
127
+ # ☀️ Surya: A Live Demonstration of NASA's Heliophysics Foundation Model ☀️
128
+ This demo runs the actual Surya model to forecast solar activity. It uses the official test data for **2014-01-07**,
129
+ allowing a direct comparison between the model's prediction and the real ground truth.
130
  </div>
131
  """
132
  )
 
 
 
 
133
 
134
  with gr.Row():
135
+ channel_selector = gr.Dropdown(
136
+ choices=ALL_CHANNELS,
137
+ value=ALL_CHANNELS[2], # Default to "AIA 171 Å"
138
+ label="🛰️ Select SDO Instrument Channel",
139
+ info="Choose which solar observation channel to visualize."
140
+ )
141
+
142
+ run_button = gr.Button("🔮 Generate Forecast for 2014-01-07", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ with gr.Row():
145
+ with gr.Column():
146
+ gr.Markdown("### ⬅️ Final Input Image")
147
+ gr.Markdown("The last image shown to the model before it makes a prediction.")
148
+ input_display = gr.Image(label="Input Observation", height=400, width=400)
149
+ with gr.Column():
150
+ gr.Markdown("### 🔮 Model's Forecast")
151
+ gr.Markdown("What the Surya model predicted the Sun would look like.")
152
+ prediction_display = gr.Image(label="Surya Prediction", height=400, width=400)
153
+ with gr.Column():
154
+ gr.Markdown("### ✅ Ground Truth")
155
+ gr.Markdown("What the Sun *actually* looked like at the forecast time.")
156
+ label_display = gr.Image(label="Actual Observation", height=400, width=400)
157
+
158
  gr.Markdown(
159
+ "--- \n"
160
+ "**Note:** The first time you run a forecast, the app will download the 366M-parameter model (~1.4 GB) and test data. Subsequent runs will be much faster. "
161
+ "The images are downscaled for display in this demo. "
162
+ "For more information, visit the [Surya Hugging Face Repository](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
163
  )
164
+
165
+ run_button.click(
166
+ fn=run_forecast,
167
+ inputs=[channel_selector],
168
+ outputs=[input_display, prediction_display, label_display]
169
  )
170
 
171
  if __name__ == "__main__":
172
+ demo.launch(debug=True)