Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Installation
|
2 |
+
# Clone the repository
|
3 |
+
#git clone https://github.com/NASA-IMPACT/Surya.git
|
4 |
+
#cd Surya
|
5 |
+
|
6 |
+
# Install dependencies (using uv as recommended, or use pip with requirements.txt if available)
|
7 |
+
#curl -LsSf https://astral.sh/uv/install.sh | sh
|
8 |
+
#source ~/.bashrc
|
9 |
+
#uv sync
|
10 |
+
#source .venv/bin/activate
|
11 |
+
|
12 |
+
# Alternatively, if using pip:
|
13 |
+
# pip install -r requirements.txt # Assuming the repo has this file
|
14 |
+
|
15 |
+
# Usage Example: Load the model and perform zero-shot forecasting
|
16 |
+
import os
|
17 |
+
import subprocess
import sys

# Install the Surya package at startup; the `surya` import further down
# depends on it being present.
# - `sys.executable -m pip` installs into the interpreter that is running this
#   script, rather than whichever `pip` happens to be first on PATH.
# - An argument list (no shell string) avoids shell parsing entirely.
# - check=True raises CalledProcessError on a failed install, instead of the
#   failure being ignored and resurfacing later as an ImportError.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/NASA-IMPACT/Surya.git",
    ],
    check=True,
)
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from huggingface_hub import hf_hub_download
|
21 |
+
from surya.model import Surya # Adjust import based on actual module/class name in repo (likely surya.model or similar)
|
22 |
+
|
23 |
+
# Download pretrained weights from Hugging Face
|
24 |
+
# Fetch the pretrained weights from the Hugging Face Hub. The result is the
# path that is later handed to torch.load.
# NOTE(review): the weights filename is a guess — adjust it to the actual file
# name in the model repo.
checkpoint_path = hf_hub_download(
    "nasa-ibm-ai4science/Surya-1.0",
    "surya.366m.v1.pt",
)
|
28 |
+
|
29 |
+
# Initialize the model (parameters inferred from architecture description)
|
30 |
+
# Constructor settings for the model.
# NOTE(review): these parameter names and values are inferred from the
# architecture description, not from the constructor itself — confirm them
# against the real class signature.
surya_config = {
    "img_size": 4096,  # native resolution 4096x4096
    "patch_size": 16,  # 16x16 patches -> 256 * 256 = 65,536 tokens
    "in_chans": 13,  # 8 AIA channels + 5 HMI products
    "embed_dim": 1280,  # internal dimension
    "spectral_blocks": 2,  # two spectral gating blocks
    "attention_blocks": 8,  # eight long-short attention layers
    # Additional params like mlp_ratio=4, norm_layer=torch.nn.LayerNorm, etc., as needed
}
model = Surya(**surya_config)
|
39 |
+
|
40 |
+
# Load weights
|
41 |
+
# Read the downloaded checkpoint onto the CPU and copy it into the model.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered or unexpected checkpoint file cannot run arbitrary code on load.
# NOTE(review): assumes the file holds a plain state_dict whose keys match the
# model's parameters — confirm against the published checkpoint.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode for inference
|
43 |
+
|
44 |
+
# Prepare input data (example: batch of multi-instrument SDO data)
|
45 |
+
# Input shape: [batch_size, channels=13, time_steps, height=4096, width=4096]
|
46 |
+
# Preprocess data as per the paper: alignment, normalization (scaled signum-log transform)
|
47 |
+
# Dummy input: 1 batch, 13 channels, 5 time steps, 4096x4096 pixels.
dummy_shape = (1, 13, 5, 4096, 4096)
input_tensor = torch.randn(dummy_shape)

# Perform inference (e.g., predict 60 minutes ahead). Gradient tracking is
# switched off because nothing is being trained here.
with torch.no_grad():
    # NOTE(review): output is presumably future SDO imagery — confirm.
    prediction = model(input_tensor)

# Post-process prediction (denormalize, visualize, etc.)
# NOTE(review): expected to resemble the input shape, shifted in time — confirm.
print(prediction.shape)
|
55 |
+
|
56 |
+
# For fine-tuning with LoRA on downstream tasks (e.g., solar flare forecasting)
|
57 |
+
# Use libraries like peft (Parameter-Efficient Fine-Tuning)
|
58 |
+
from peft import LoraConfig, get_peft_model
|
59 |
+
|
60 |
+
# LoRA adapter settings: rank 16, alpha 32, dropout 0.05.
# NOTE(review): "attention" is meant to target the long-short attention
# modules; it has to match module names that actually exist in the model —
# confirm against the model's layer names.
lora_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["attention"], lora_dropout=0.05
)

# Wrap the base model with the adapters described by lora_config; `model` is
# rebound to the wrapped model from here on.
model = get_peft_model(model, lora_config)
|
67 |
+
|
68 |
+
# Then train on your dataset
|
69 |
+
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
70 |
+
# ... training loop ...
|
71 |
+
|
72 |
+
# Refer to downstream_examples in the repo for specific tasks like finetune.py for solar flare forecasting
|
73 |
+
# Example: cd downstream_examples/solar_flare_forecasting; torchrun --nproc_per_node=1 finetune.py
|