# surya-demo / app.py
# Hugging Face Space by broadfield-dev — verified commit bd0dd03 ("Update app.py"), 3.01 kB.
# Installation
# Clone the repository:
#   git clone https://github.com/NASA-IMPACT/Surya.git
#   cd Surya
# Install dependencies (uv is recommended; pip with requirements.txt also works if available):
#   curl -LsSf https://astral.sh/uv/install.sh | sh
#   source ~/.bashrc
#   uv sync
#   source .venv/bin/activate
# Alternatively, with pip:
#   pip install -r requirements.txt  # assuming the repo ships this file
# Usage example: load the model and perform zero-shot forecasting.
import os
import subprocess
import sys

# Print the runtime environment for debugging: installed packages and interpreter version.
# `sys.executable -m pip` targets the interpreter running this script rather than whichever
# `pip`/`python` happens to be first on PATH. Capital `-V` prints the version; lowercase
# `-v` would instead enable verbose import tracing and, with no script argument, start an
# interpreter reading stdin.
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
subprocess.run([sys.executable, "-V"], check=False)

# Install Surya from source. check=True raises CalledProcessError on failure, so the script
# stops here instead of failing later at the `surya` import with a less obvious error.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "git+https://github.com/NASA-IMPACT/Surya.git"],
    check=True,
)
import torch
from huggingface_hub import hf_hub_download
from surya.model import Surya # Adjust import based on actual module/class name in repo (likely surya.model or similar)
# Fetch the pretrained weights from the Hugging Face Hub; hf_hub_download returns the
# local filesystem path of the downloaded file.
_SURYA_REPO_ID = "nasa-ibm-ai4science/Surya-1.0"
# NOTE(review): the weights filename was not verified — confirm it exists in the Hub repo.
_SURYA_WEIGHTS_FILE = "surya.366m.v1.pt"
checkpoint_path = hf_hub_download(repo_id=_SURYA_REPO_ID, filename=_SURYA_WEIGHTS_FILE)
# Build the model.
# NOTE(review): these constructor arguments were inferred from the architecture
# description, not read from the Surya source. Confirm them against the model class
# signature and the checkpoint before relying on them.
model = Surya(
    img_size=4096,       # native 4096x4096 resolution
    patch_size=16,       # 16x16 patches -> (4096 / 16) ** 2 = 65,536 tokens
    in_chans=13,         # 8 AIA channels + 5 HMI products
    embed_dim=1280,      # internal embedding dimension
    spectral_blocks=2,   # two spectral gating blocks
    attention_blocks=8,  # eight long-short attention layers
    # Further options (e.g. mlp_ratio=4, norm_layer=torch.nn.LayerNorm) may be needed.
)

# Load the pretrained weights. The checkpoint comes from the network, so weights_only=True
# restricts unpickling to tensors and plain containers; arbitrary pickled objects (which
# could execute code on load) are rejected. This matches torch.load's default from
# PyTorch 2.6 onward and makes older versions behave the same way.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
model.eval()  # put dropout / batch-norm style layers into inference behaviour
# Build a placeholder batch standing in for multi-instrument SDO observations. Real data
# would first need the paper's preprocessing (alignment, scaled signum-log normalisation).
# NOTE(review): the (batch, channels, time_steps, height, width) layout is assumed —
# confirm it against the model's forward() signature.
# Memory: 1 * 13 * 5 * 4096 * 4096 float32 values is about 4.4 GB for the input alone.
batch_size, n_channels, n_time_steps, height, width = 1, 13, 5, 4096, 4096
input_tensor = torch.randn(batch_size, n_channels, n_time_steps, height, width)

# Forward pass without autograd bookkeeping (intended use: e.g. a 60-minute-ahead forecast).
with torch.no_grad():
    prediction = model(input_tensor)

# A real pipeline would denormalise / visualise the forecast; here only its shape is shown.
print(prediction.shape)
# Optional: parameter-efficient fine-tuning with LoRA for downstream tasks such as
# solar-flare forecasting, using the `peft` library.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,               # rank of the low-rank adapter matrices
    lora_alpha=32,      # adapter scaling numerator (scaled by lora_alpha / r)
    lora_dropout=0.05,
    # NOTE(review): peft matches these names against submodule names and only wraps
    # supported layer types (e.g. nn.Linear). Confirm that "attention" names such layers
    # in this model; otherwise get_peft_model raises.
    target_modules=["attention"],
)
model = get_peft_model(model, lora_config)

# Training is left to the caller, for example:
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     ... training loop ...
# Complete task scripts live under downstream_examples/ in the Surya repo; for
# solar-flare forecasting:
#     cd downstream_examples/solar_flare_forecasting
#     torchrun --nproc_per_node=1 finetune.py