File size: 3,010 Bytes
5f00446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7826aa
bd0dd03
5f00446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Installation
# Clone the repository
#git clone https://github.com/NASA-IMPACT/Surya.git
#cd Surya

# Install dependencies (using uv as recommended, or use pip with requirements.txt if available)
#curl -LsSf https://astral.sh/uv/install.sh | sh
#source ~/.bashrc
#uv sync
#source .venv/bin/activate

# Alternatively, if using pip:
# pip install -r requirements.txt  # Assuming the repo has this file

# Usage Example: Load the model and perform zero-shot forecasting
import os
import importlib
import subprocess
import sys

# Environment diagnostics: list installed packages and print the interpreter
# version. Both run through sys.executable so they describe the interpreter
# executing this script, not whichever `pip`/`python` is first on PATH.
# They are informational only, so a failure here is not fatal (check=False).
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
# `-V` prints the version. (Lowercase `-v` enables verbose import tracing and,
# with no script argument, starts the interpreter reading from stdin.)
subprocess.run([sys.executable, "-V"], check=False)

# Install Surya into *this* interpreter's environment. check=True makes a failed
# install stop here with CalledProcessError instead of surfacing later as a
# confusing ImportError on the `surya` import below.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/NASA-IMPACT/Surya.git",
    ],
    check=True,
)
# The package was installed after the interpreter started; clear the import
# system's finder caches so the subsequent `import surya...` can see it.
importlib.invalidate_caches()

import torch
from huggingface_hub import hf_hub_download
from surya.model import Surya  # Adjust import based on actual module/class name in repo (likely surya.model or similar)

# Fetch the pretrained Surya checkpoint from the Hugging Face Hub; the returned
# value is the local path of the downloaded file, which is loaded further below.
SURYA_REPO_ID = "nasa-ibm-ai4science/Surya-1.0"
# NOTE(review): adjust if the repo's weights file is named differently.
SURYA_WEIGHTS_FILENAME = "surya.366m.v1.pt"

checkpoint_path = hf_hub_download(
    repo_id=SURYA_REPO_ID,
    filename=SURYA_WEIGHTS_FILENAME,
)

# Build the Surya backbone. These hyperparameters are inferred from the
# published architecture description rather than read from the checkpoint,
# so they need to agree with the downloaded weights for load_state_dict to succeed.
surya_config = {
    "img_size": 4096,        # native 4096 x 4096 resolution
    "patch_size": 16,        # 16 x 16 patches -> (4096 / 16) ** 2 = 65,536 tokens
    "in_chans": 13,          # 8 AIA channels + 5 HMI products
    "embed_dim": 1280,       # internal embedding width
    "spectral_blocks": 2,    # spectral gating blocks
    "attention_blocks": 8,   # long-short attention layers
    # Further options (e.g. mlp_ratio=4, norm_layer=torch.nn.LayerNorm) can be added as needed.
}
model = Surya(**surya_config)

# Load the pretrained weights into the freshly constructed model.
# NOTE(review): torch.load unpickles the file, and this checkpoint comes from
# the network. weights_only=True restricts unpickling to tensors and plain
# containers. It is the default from PyTorch 2.6 on; passing it explicitly
# keeps older versions equally safe. This assumes the file is a plain state
# dict. If it wraps one (e.g. under a "model" key), unwrap it before loading.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode for inference

# Dummy stand-in for a batch of multi-instrument SDO observations. Real data
# should be aligned and normalized as described in the paper (scaled
# signum-log transform) before being fed to the model.
# NOTE(review): layout assumed to be (batch, channels, time_steps, height, width);
# confirm against Surya's forward().
# With torch's default float32 dtype this tensor alone is
# 1 * 13 * 5 * 4096 * 4096 * 4 bytes, i.e. about 4.4 GB.
batch_size, n_channels, n_time_steps = 1, 13, 5
height = width = 4096
input_tensor = torch.randn(batch_size, n_channels, n_time_steps, height, width)

# Forward pass with autograd disabled (e.g. a 60-minutes-ahead forecast);
# the output is the model's predicted future SDO imagery.
with torch.set_grad_enabled(False):
    prediction = model(input_tensor)

# Post-processing (denormalization, visualization, ...) would go here.
# For now just report the output shape, expected to resemble the input's,
# shifted in time.
prediction_shape = prediction.shape
print(prediction_shape)

# Parameter-efficient fine-tuning: wrap the model with LoRA adapters (via the
# peft library) for downstream tasks such as solar flare forecasting.
from peft import LoraConfig, get_peft_model

lora_settings = dict(
    r=16,              # adapter rank
    lora_alpha=32,
    lora_dropout=0.05,
    # NOTE(review): intended to hit the long-short attention modules; the name
    # must match module names inside Surya. Confirm against the model definition.
    target_modules=["attention"],
)
model = get_peft_model(model, LoraConfig(**lora_settings))

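# Then train on your dataset. The model was put in eval mode earlier, so call
# model.train() before the training loop to re-enable training-time behavior
# (e.g. the LoRA dropout).
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# ... training loop ...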

# See the repo's downstream_examples directory for task-specific scripts,
# e.g. finetune.py for solar flare forecasting:
#   cd downstream_examples/solar_flare_forecasting; torchrun --nproc_per_node=1 finetune.py