broadfield-dev committed on
Commit
f966038
·
verified ·
1 Parent(s): 69e42f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -73
app.py CHANGED
@@ -1,75 +1,135 @@
1
- # Installation
2
- # Clone the repository
3
- #git clone https://github.com/NASA-IMPACT/Surya.git
4
- #cd Surya
5
-
6
- # Install dependencies (using uv as recommended, or use pip with requirements.txt if available)
7
- #curl -LsSf https://astral.sh/uv/install.sh | sh
8
- #source ~/.bashrc
9
- #uv sync
10
- #source .venv/bin/activate
11
-
12
- # Alternatively, if using pip:
13
- # pip install -r requirements.txt # Assuming the repo has this file
14
-
15
- # Usage Example: Load the model and perform zero-shot forecasting
16
- #import os
17
- #os.system("pip freeze")
18
- #os.system("python -v")
19
- #os.system("pip install git+https://github.com/NASA-IMPACT/Surya.git")
20
-
21
  import torch
22
  from huggingface_hub import hf_hub_download
23
- from surya.model import Surya # Adjust import based on actual module/class name in repo (likely surya.model or similar)
24
-
25
- # Download pretrained weights from Hugging Face
26
- checkpoint_path = hf_hub_download(
27
- repo_id="nasa-ibm-ai4science/Surya-1.0",
28
- filename="surya.366m.v1.pt" # Adjust filename based on actual weights file in the repo
29
- )
30
-
31
- # Initialize the model (parameters inferred from architecture description)
32
- model = Surya(
33
- img_size=4096, # Native resolution 4096x4096
34
- patch_size=16, # Patch size 16x16, resulting in 65,536 tokens
35
- in_chans=13, # 8 AIA channels + 5 HMI products
36
- embed_dim=1280, # Internal dimension
37
- spectral_blocks=2, # Two spectral gating blocks
38
- attention_blocks=8, # Eight long-short attention layers
39
- # Additional params like mlp_ratio=4, norm_layer=torch.nn.LayerNorm, etc., as needed
40
- )
41
-
42
- # Load weights
43
- model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
44
- model.eval() # Set to evaluation mode for inference
45
-
46
- # Prepare input data (example: batch of multi-instrument SDO data)
47
- # Input shape: [batch_size, channels=13, time_steps, height=4096, width=4096]
48
- # Preprocess data as per the paper: alignment, normalization (scaled signum-log transform)
49
- input_tensor = torch.randn(1, 13, 5, 4096, 4096) # Dummy input: 1 batch, 13 channels, 5 time steps
50
-
51
- # Perform inference (e.g., predict 60 minutes ahead)
52
- with torch.no_grad():
53
- prediction = model(input_tensor) # Output: future SDO imagery
54
-
55
- # Post-process prediction (denormalize, visualize, etc.)
56
- print(prediction.shape) # Expected: similar shape to input, shifted in time
57
-
58
- # For fine-tuning with LoRA on downstream tasks (e.g., solar flare forecasting)
59
- # Use libraries like peft (Parameter-Efficient Fine-Tuning)
60
- from peft import LoraConfig, get_peft_model
61
-
62
- lora_config = LoraConfig(
63
- r=16, # Rank
64
- lora_alpha=32,
65
- target_modules=["attention"], # Target long-short attention modules
66
- lora_dropout=0.05
67
- )
68
- model = get_peft_model(model, lora_config)
69
-
70
- # Then train on your dataset
71
- # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
72
- # ... training loop ...
73
-
74
- # Refer to downstream_examples in the repo for specific tasks like finetune.py for solar flare forecasting
75
- # Example: cd downstream_examples/solar_flare_forecasting; torchrun --nproc_per_node=1 finetune.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
+ from surya.model import Surya
5
+ import numpy as np
6
+ from PIL import Image
7
+ import warnings
8
+
9
# Suppress warnings for a cleaner demo experience.
# NOTE(review): with no category or module argument this filter applies
# process-wide and hides every warning, including deprecation notices from
# the imported libraries; consider narrowing it.
warnings.filterwarnings("ignore")
11
+
12
+ # --- Model Loading ---
13
import functools


# functools.lru_cache is used for memoization: it is part of the standard
# library, so the module no longer depends on a `gr.cache` attribute being
# present in the installed Gradio version.
@functools.lru_cache(maxsize=1)
def load_model():
    """Download the pre-trained Surya weights and build the model for inference.

    The result is memoized, so repeated calls return the same model object
    instead of downloading and deserializing the checkpoint again.

    Returns:
        The ``Surya`` model with the checkpoint's state dict loaded (tensors
        mapped to CPU) and switched to evaluation mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        filename="surya.366m.v1.pt",
    )

    # NOTE(review): these constructor arguments are not verified here against
    # the installed `surya` package -- confirm names and values before relying
    # on them.
    surya_model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=13,
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    # NOTE(security): torch.load unpickles the downloaded file. If the
    # checkpoint is a plain state dict, passing weights_only=True would avoid
    # executing arbitrary pickled code -- TODO confirm and enable.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    surya_model.load_state_dict(state_dict)
    surya_model.eval()  # inference only: disables dropout / batch-norm updates
    return surya_model


# Loaded once at import time so every request handler shares the same model.
model = load_model()
38
+
39
+ # --- Core Prediction Logic ---
40
def _to_uint8_image_array(frame):
    """Min-max scale a 2-D tensor to 0-255 and return it as a uint8 NumPy array."""
    low = frame.min()
    span = frame.max() - low
    if span > 0:
        scaled = (frame - low) / span
    else:
        # Constant (or NaN) frame: dividing by the span would produce NaN,
        # whose conversion to uint8 is undefined, so emit a black image.
        scaled = torch.zeros_like(frame)
    return (scaled * 255).byte().cpu().numpy()


def predict_solar_activity(time_steps, forecast_horizon):
    """Run the Surya model on a random input and return demo outputs.

    This is a demo: the input is random noise, not real SDO data, and the
    flare result is a random number rather than a model output.

    Args:
        time_steps: Number of input time steps. UI sliders may deliver this
            as a float, so it is converted to ``int`` before use.
        forecast_horizon: Forecast horizon in hours, as selected in the UI.
            Currently unused; accepted so the UI wiring stays stable.

    Returns:
        A ``(image, label)`` tuple: a PIL image of one predicted frame, and a
        one-entry dict mapping the flare class name to its confidence.
    """
    # torch.randn requires integer sizes; a float from the slider would raise.
    num_steps = int(time_steps)

    # Dummy input, laid out as [batch, channels, time_steps, height, width].
    # With the default float32 dtype this is roughly 0.87 GB per time step.
    # A real application would substitute aligned, normalized SDO data here.
    dummy_input = torch.randn(1, 13, num_steps, 4096, 4096)

    with torch.no_grad():
        prediction = model(dummy_input)

    # Visualize the first channel at the last time index.
    # NOTE(review): assumes the model output is 5-D with the same axis order
    # as the input -- confirm against the model's forward().
    predicted_frame = prediction[0, 0, -1, :, :]
    predicted_image = Image.fromarray(_to_uint8_image_array(predicted_frame))

    # Placeholder flare forecast: a uniform random draw, not a model output.
    flare_probability = np.random.rand()
    if flare_probability > 0.5:
        flare_class = "M-class or X-class Flare"
        confidence = flare_probability
    else:
        flare_class = "No significant flare"
        confidence = 1 - flare_probability

    return predicted_image, {flare_class: confidence}
82
+
83
+ # --- Gradio Interface Definition ---
84
# Static page copy, hoisted out of the layout so the structure below is
# easier to scan.
_INTRO_TEXT = (
    "Surya is a 366M-parameter foundation model trained on full-resolution, multi-instrument SDO observations. "
    "This demo showcases its capability to forecast solar dynamics."
)
_DISCLAIMER_TEXT = (
    "**Note:** This demo uses a placeholder for real-time data fetching and displays a simulated prediction. "
    "The core of this application is the loaded Surya model from NASA and IBM."
)
_MODEL_CARD_TEXT = (
    "For more information, visit the [Surya model card on Hugging Face](https://huggingface.co/nasa-ibm-ai4science/Surya-1.0)."
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header and short description.
    gr.Markdown(
        """
        <div align="center">
        # ☀️ Surya: Foundation Model for Heliophysics ☀️
        *A Gradio Demo for NASA's Solar Foundation Model*
        </div>
        """
    )
    gr.Markdown(_INTRO_TEXT)

    with gr.Row():
        # Left column: user-controlled inputs and the trigger button.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Prediction Parameters")
            time_steps_slider = gr.Slider(
                label="Number of Input Time Steps (12-min cadence)",
                info="Represents the sequence of past solar observations to feed the model.",
                minimum=1,
                maximum=10,
                value=5,
                step=1,
            )
            forecast_horizon_slider = gr.Slider(
                label="Forecast Horizon (hours)",
                info="How far into the future to predict.",
                minimum=1,
                maximum=24,
                value=1,
            )
            predict_button = gr.Button("🔮 Generate Forecast", variant="primary")

        # Right column: the two outputs filled in by predict_solar_activity.
        with gr.Column(scale=2):
            gr.Markdown("### 🛰️ Predicted Solar Image")
            output_image = gr.Image(
                label="Forecasted SDO Image (AIA 171 Å)",
                height=512,
                width=512,
            )
            gr.Markdown("### 💥 Solar Flare Prediction")
            output_flare = gr.Label(label="Flare Probability")

    # Wire the button: both slider values in, image and label out.
    predict_button.click(
        predict_solar_activity,
        [time_steps_slider, forecast_horizon_slider],
        [output_image, output_flare],
    )

    # Footer: disclaimer and link to the model card.
    gr.Markdown("---")
    gr.Markdown(_DISCLAIMER_TEXT)
    gr.Markdown(_MODEL_CARD_TEXT)

if __name__ == "__main__":
    demo.launch()