Spaces:
Running
Running
# Installation
# Clone the repository
# git clone https://github.com/NASA-IMPACT/Surya.git
# cd Surya
# Install dependencies (using uv as recommended, or use pip with requirements.txt if available)
# curl -LsSf https://astral.sh/uv/install.sh | sh
# source ~/.bashrc
# uv sync
# source .venv/bin/activate
# Alternatively, if using pip:
# pip install -r requirements.txt  # Assuming the repo has this file
# Usage Example: Load the model and perform zero-shot forecasting
import os
import subprocess
import sys

# Log the installed packages and interpreter version for debugging. These are
# best-effort diagnostics, so a non-zero exit status is deliberately ignored.
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
# "-V" prints the interpreter version. The lowercase "-v" used previously is
# the verbose-import flag and does not report the version.
subprocess.run([sys.executable, "-V"], check=False)

# Install Surya at runtime, before the script imports it. Running pip through
# sys.executable installs into the interpreter that is executing this script,
# and the argument list avoids passing a command string through a shell.
_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/NASA-IMPACT/Surya.git",
    ],
    check=False,
)
if _install.returncode != 0:
    # Stay non-fatal (the package may already be installed), but make the
    # failure visible instead of silently discarding the exit status.
    print(
        f"warning: pip install of Surya exited with status {_install.returncode}",
        file=sys.stderr,
    )
import torch
from huggingface_hub import hf_hub_download
# NOTE: the module path and class name below are assumptions; adjust them to
# the actual module/class name in the repo (likely surya.model or similar).
from surya.model import Surya

# Download the pretrained weights from Hugging Face.
weights_repo = "nasa-ibm-ai4science/Surya-1.0"
# NOTE: adjust the filename to the actual weights file published in the repo.
weights_file = "surya.366m.v1.pt"
checkpoint_path = hf_hub_download(repo_id=weights_repo, filename=weights_file)
# Build the model. The hyperparameters below are inferred from the
# architecture description rather than read from a config file -- TODO confirm
# the constructor signature against the repo.
model = Surya(
    img_size=4096,  # Native resolution 4096x4096
    patch_size=16,  # Patch size 16x16, resulting in 65,536 tokens
    in_chans=13,  # 8 AIA channels + 5 HMI products
    embed_dim=1280,  # Internal dimension
    spectral_blocks=2,  # Two spectral gating blocks
    attention_blocks=8,  # Eight long-short attention layers
    # Additional params like mlp_ratio=4, norm_layer=torch.nn.LayerNorm, etc., as needed
)

# Load the pretrained weights onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers, so a downloaded checkpoint
# cannot execute arbitrary code while it is being deserialized.
# NOTE(review): assumes the file holds a bare state_dict -- confirm.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode for inference
# Build a placeholder input standing in for a batch of multi-instrument SDO data.
# Layout: [batch_size, channels=13, time_steps, height=4096, width=4096].
# Real data should be preprocessed as per the paper: alignment and
# normalization (scaled signum-log transform).
dummy_shape = (1, 13, 5, 4096, 4096)  # 1 batch, 13 channels, 5 time steps
input_tensor = torch.randn(dummy_shape)

# Run inference with gradient tracking disabled (e.g., predict 60 minutes ahead).
with torch.no_grad():
    prediction = model(input_tensor)  # Output: future SDO imagery

# Post-process the prediction here (denormalize, visualize, etc.).
print(prediction.shape)  # Expected: similar shape to input, shifted in time
# Fine-tuning with LoRA on downstream tasks (e.g., solar flare forecasting),
# using peft (Parameter-Efficient Fine-Tuning).
from peft import LoraConfig, get_peft_model

lora_settings = {
    "r": 16,  # Rank
    "lora_alpha": 32,
    # Target long-short attention modules.
    # NOTE(review): verify "attention" matches the model's module names.
    "target_modules": ["attention"],
    "lora_dropout": 0.05,
}
lora_config = LoraConfig(**lora_settings)

# Rebind `model` to the adapter-wrapped model returned by peft.
model = get_peft_model(model, lora_config)
# Then train on your dataset
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# ... training loop ...
# Refer to downstream_examples in the repo for specific tasks like finetune.py for solar flare forecasting
# Example: cd downstream_examples/solar_flare_forecasting; torchrun --nproc_per_node=1 finetune.py |