# Installation

```bash
# Clone the repository
git clone https://github.com/NASA-IMPACT/Surya.git
cd Surya

# Install dependencies (using uv, as recommended)
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.bashrc
uv sync
source .venv/bin/activate

# Alternatively, if using pip:
# pip install -r requirements.txt   # assuming the repo ships this file
# pip install git+https://github.com/NASA-IMPACT/Surya.git
```

# Usage Example: Load the model and perform zero-shot forecasting

```python
import torch
from huggingface_hub import hf_hub_download
from surya.model import Surya  # adjust import to the actual module/class name in the repo

# Download pretrained weights from Hugging Face
checkpoint_path = hf_hub_download(
    repo_id="nasa-ibm-ai4science/Surya-1.0",
    filename="surya.366m.v1.pt",  # adjust to the actual weights filename in the repo
)

# Initialize the model (parameters inferred from the architecture description)
model = Surya(
    img_size=4096,        # native resolution 4096x4096
    patch_size=16,        # 16x16 patches, i.e. 65,536 tokens per image
    in_chans=13,          # 8 AIA channels + 5 HMI products
    embed_dim=1280,       # internal embedding dimension
    spectral_blocks=2,    # two spectral gating blocks
    attention_blocks=8,   # eight long-short attention layers
    # additional params such as mlp_ratio=4, norm_layer=torch.nn.LayerNorm, etc., as needed
)

# Load weights and switch to evaluation mode for inference
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# Prepare input data (example: a batch of multi-instrument SDO observations)
# Input shape: [batch_size, channels=13, time_steps, height=4096, width=4096]
# Preprocess as described in the paper: alignment and normalization (scaled signum-log transform)
input_tensor = torch.randn(1, 13, 5, 4096, 4096)  # dummy input: 1 batch, 13 channels, 5 time steps

# Run inference (e.g., predict 60 minutes ahead)
with torch.no_grad():
    prediction = model(input_tensor)  # output: future SDO imagery

# Post-process the prediction (denormalize, visualize, etc.)
print(prediction.shape)  # expected: same channel/spatial shape as the input, shifted in time
```

```python
# For fine-tuning with LoRA on downstream tasks (e.g., solar flare forecasting),
# use a library such as peft (Parameter-Efficient Fine-Tuning)
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                          # LoRA rank
    lora_alpha=32,
    target_modules=["attention"],  # target the long-short attention modules
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)

# Then train on your dataset, e.g.:
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# ... training loop ...
```

Refer to `downstream_examples` in the repo for task-specific scripts, e.g. `finetune.py` for solar flare forecasting:

```bash
cd downstream_examples/solar_flare_forecasting
torchrun --nproc_per_node=1 finetune.py
```
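
The usage example above only mentions the scaled signum-log normalization in a comment. Below is a minimal sketch of that preprocessing step, assuming the common form sign(x) · log(1 + |x|/scale); the helper names, the per-channel scale values, and the reduced tensor size are illustrative placeholders, not taken from the Surya repo.

```python
import torch

def signum_log_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Scaled signum-log transform: sign(x) * log(1 + |x|/scale).

    Compresses the large dynamic range of SDO imagery while preserving sign
    (HMI magnetograms are signed). The per-channel `scale` constants used for
    Surya's training data are assumptions here, not taken from the repo.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x) / scale)

def inverse_signum_log_transform(y: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Inverse transform, useful for denormalizing model predictions."""
    return torch.sign(y) * torch.expm1(torch.abs(y)) * scale

# Example: normalize a dummy 13-channel input channel by channel
# (spatial size reduced from 4096 to 64 purely to keep this illustration light)
x = torch.randn(1, 13, 5, 64, 64)
scales = torch.ones(13)  # placeholder per-channel scales
x_norm = torch.stack(
    [signum_log_transform(x[:, c], scale=scales[c].item()) for c in range(x.shape[1])],
    dim=1,
)
print(x_norm.shape)  # torch.Size([1, 13, 5, 64, 64])
```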
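
The commented-out training loop in the LoRA snippet can be expanded along the following lines. This is a sketch only: the dummy DataLoader shapes, the MSE objective, and the assumption that the PEFT-wrapped model accepts these inputs and returns a tensor matching the targets are all placeholders; the actual flare-forecasting objective and data pipeline live in the repo's `downstream_examples` scripts.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy (input, target) pairs purely to illustrate the loop structure;
# real fine-tuning uses full-resolution, signum-log-normalized SDO data.
inputs = torch.randn(4, 13, 5, 64, 64)
targets = torch.randn(4, 13, 64, 64)
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)

# `model` is the peft-wrapped Surya model from the snippet above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()  # illustrative objective, not the repo's loss

model.train()
for epoch in range(3):
    for batch_inputs, batch_targets in train_loader:
        optimizer.zero_grad()
        predictions = model(batch_inputs)
        loss = loss_fn(predictions, batch_targets)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
```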