broadfield-dev committed on
Commit
5f00446
·
verified ·
1 Parent(s): 827fdf1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Usage example for the Surya model: install the package, download the
# pretrained weights, run one forward pass on dummy data, and wrap the model
# with LoRA adapters for fine-tuning.
#
# Installation (manual alternative to the runtime install below)
# Clone the repository:
#   git clone https://github.com/NASA-IMPACT/Surya.git
#   cd Surya
#
# Install dependencies (using uv as recommended, or pip with requirements.txt if available):
#   curl -LsSf https://astral.sh/uv/install.sh | sh
#   source ~/.bashrc
#   uv sync
#   source .venv/bin/activate
#
# Alternatively, if using pip:
#   pip install -r requirements.txt  # Assuming the repo has this file

# Usage Example: Load the model and perform zero-shot forecasting
import os
import subprocess
import sys

# Install Surya into the environment of the interpreter that is running this
# script. `sys.executable -m pip` guarantees the right environment is targeted,
# and the argument list avoids going through a shell.
_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/NASA-IMPACT/Surya.git",
    ],
    check=False,
)
if _install.returncode != 0:
    # Best effort: the package may already be installed, in which case the
    # import below still succeeds. Otherwise the import error will follow
    # this warning and point at the real cause.
    print(
        f"warning: pip install of Surya exited with status {_install.returncode}",
        file=sys.stderr,
    )

# These imports must come after the install step above.
import torch
from huggingface_hub import hf_hub_download
# NOTE(review): import path is a guess; adjust to the actual module/class name in the repo.
from surya.model import Surya

# Download pretrained weights from Hugging Face
checkpoint_path = hf_hub_download(
    repo_id="nasa-ibm-ai4science/Surya-1.0",
    # NOTE(review): adjust filename to the actual weights file in the repo.
    filename="surya.366m.v1.pt",
)

# Initialize the model.
# NOTE(review): parameter names/values are inferred from the architecture
# description, not from the package's actual constructor signature - confirm.
model = Surya(
    img_size=4096,  # Native resolution 4096x4096
    patch_size=16,  # Patch size 16x16 -> (4096 / 16) ** 2 = 65,536 tokens
    in_chans=13,  # 8 AIA channels + 5 HMI products
    embed_dim=1280,  # Internal dimension
    spectral_blocks=2,  # Two spectral gating blocks
    attention_blocks=8,  # Eight long-short attention layers
    # Additional params like mlp_ratio=4, norm_layer=torch.nn.LayerNorm, etc., as needed
)

# Load weights. `weights_only=True` restricts unpickling to tensors and plain
# containers, so a downloaded checkpoint cannot execute arbitrary code.
# NOTE(review): assumes the file holds a bare state_dict - confirm.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode for inference

# Prepare input data (example: batch of multi-instrument SDO data)
# Input shape: [batch_size, channels=13, time_steps, height=4096, width=4096]
# Preprocess real data as per the paper: alignment, normalization (scaled signum-log transform)
# NOTE: this dummy tensor alone holds 13 * 5 * 4096 * 4096 float32 values (~4.4 GB).
input_tensor = torch.randn(1, 13, 5, 4096, 4096)  # 1 batch, 13 channels, 5 time steps

# Perform inference (e.g., predict 60 minutes ahead)
with torch.no_grad():
    prediction = model(input_tensor)  # Output: future SDO imagery

# Post-process prediction (denormalize, visualize, etc.)
print(prediction.shape)  # Expected: similar shape to input, shifted in time

# For fine-tuning with LoRA on downstream tasks (e.g., solar flare forecasting)
# Use libraries like peft (Parameter-Efficient Fine-Tuning)
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,
    # NOTE(review): "attention" is a guessed module-name pattern; check
    # model.named_modules() for the real names of the attention layers.
    target_modules=["attention"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)

# Then train on your dataset (call model.train() first; eval mode was set above)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# ... training loop ...

# Refer to downstream_examples in the repo for specific tasks like finetune.py for solar flare forecasting
# Example: cd downstream_examples/solar_flare_forecasting; torchrun --nproc_per_node=1 finetune.py