import gradio as gr
import torch
from huggingface_hub import snapshot_download
import yaml
import numpy as np
from PIL import Image
import requests
import os
import warnings
import logging
import datetime
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import traceback
from io import BytesIO
import re
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers
from surya.datasets.helio import inverse_transform_single_channel
# Hide UserWarnings raised from sunpy modules and all FutureWarnings so the
# application log is not flooded by third-party notices.
warnings.filterwarnings("ignore", category=UserWarning, module="sunpy")
warnings.filterwarnings("ignore", category=FutureWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide cache. setup_and_load_model() stores the parsed config, the
# scalers, the device string and the model here so they are built only once.
APP_CACHE = {}

# Channel name -> code used in the browse-image filenames that
# find_nearest_browse_image_url() searches for (``..._4096_<code>.jpg``).
CHANNEL_TO_URL_CODE = {
    "aia94": "0094",
    "aia131": "0131",
    "aia171": "0171",
    "aia193": "0193",
    "aia211": "0211",
    "aia304": "0304",
    "aia335": "0335",
    "aia1600": "1600",
    "hmi_m": "HMIBC",
    "hmi_bx": "HMIB",
    "hmi_by": "HMIB",
    "hmi_bz": "HMIB",
    "hmi_v": "HMID",
}

# Channel names in dictionary insertion order; this order is also used as the
# channel axis order when tensors are stacked elsewhere in this file.
SDO_CHANNELS = list(CHANNEL_TO_URL_CODE)
def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress messages.

    This is a generator. Each yielded item is a human-readable status string
    intended for the UI log. On success ``APP_CACHE`` holds the keys
    ``"config"``, ``"scalers"``, ``"device"`` and ``"model"``.

    ``"model"`` is stored last, so a run that fails part-way leaves no
    ``"model"`` entry and the next call repeats the whole setup. When
    ``"model"`` is already cached, a single "already loaded" message is
    yielded and nothing else happens.

    Yields:
        str: Progress messages.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    model_dir = "data/Surya-1.0"

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir=model_dir,
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
    )

    yield "Loading configuration and data scalers..."
    with open(f"{model_dir}/config.yaml", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Use a context manager so the scalers file handle is closed promptly.
    with open(f"{model_dir}/scalers.yaml", "r", encoding="utf-8") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={
            "type": "linear",
            "time_dim": len(config["data"]["time_delta_input_minutes"]),
        },
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        # NOTE(review): this reuses the "learned_flow" setting rather than a
        # dedicated key; kept as-is -- confirm that this is intentional.
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    yield f"Loading model weights to {device}..."
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs and avoids executing arbitrary pickled
    # code from the downloaded checkpoint.
    weights = torch.load(
        f"{model_dir}/surya.366m.v1.pt",
        map_location=torch.device(device),
        weights_only=True,
    )
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    APP_CACHE["model"] = model

    yield "✅ Model setup complete."
def find_nearest_browse_image_url(channel, target_dt, timeout=30):
    """Return the URL of the browse image closest in time to ``target_dt``.

    The directory listing for ``target_dt``'s own date is tried first, then
    the previous day's. The first listing that contains at least one
    parseable 4096-pixel image for the channel is used, and within it the
    file whose timestamp is nearest to ``target_dt`` is chosen (the first one
    listed wins a tie).

    Args:
        channel: A key of ``CHANNEL_TO_URL_CODE``.
        target_dt: ``datetime.datetime`` to search around. It is compared
            against naive timestamps parsed from filenames, so it should be
            naive as well.
        timeout: Seconds to wait for each directory listing before giving up,
            so a stalled server cannot block the caller indefinitely.

    Returns:
        str: Full URL of the selected image.

    Raises:
        KeyError: If ``channel`` is not in ``CHANNEL_TO_URL_CODE``.
        FileNotFoundError: If neither day's listing yields a usable image.
        requests.RequestException: On network failure or timeout.
    """
    url_code = CHANNEL_TO_URL_CODE[channel]
    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
    filename_pattern = re.compile(
        r'href="(\d{8}_\d{6}_4096_' + re.escape(url_code) + r'\.jpg)"'
    )

    for days_back in range(2):
        day = target_dt - datetime.timedelta(days=days_back)
        dir_url = f"{base_url}/{day:%Y/%m/%d}/"
        response = requests.get(dir_url, timeout=timeout)
        if response.status_code != 200:
            continue

        candidates = []
        for fname in filename_pattern.findall(response.text):
            try:
                # Filenames start with YYYYMMDD_HHMMSS; parse the date from
                # the name itself instead of assuming the directory's date.
                img_dt = datetime.datetime.strptime(fname[:15], "%Y%m%d_%H%M%S")
            except ValueError:
                # Digits that do not form a valid timestamp: skip this entry.
                continue
            diff = abs((target_dt - img_dt).total_seconds())
            candidates.append((diff, fname))

        if candidates:
            # min() returns the first minimal element, so ties resolve to the
            # earliest entry in the listing.
            _, best_filename = min(candidates, key=lambda pair: pair[0])
            return dir_url + best_filename

    raise FileNotFoundError(f"Could not find any browse images for {channel} in the last 48 hours.")
def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
    """Download browse images and turn them into a model input tensor.

    This is a generator. It yields ``str`` progress messages while working
    and, as its final item, one tuple
    ``(input_tensor, last_input_image_map, target_image_map)``:

    * ``input_tensor`` is built by stacking per-channel ``(img_size,
      img_size)`` arrays into ``(C, H, W)``, stacking the input time steps on
      ``dim=1`` and prepending a batch axis, i.e. ``(1, C, T, H, W)``.
    * ``last_input_image_map`` maps channel name to the downloaded PIL image
      for the last entry of the configured input times.
    * ``target_image_map`` maps channel name to the downloaded PIL image for
      ``target_dt + forecast_horizon_minutes``.

    Requires ``APP_CACHE["config"]`` and ``APP_CACHE["scalers"]`` to be set.

    Args:
        target_dt: Reference ``datetime.datetime``; the configured input
            offsets (minutes) are added to it.
        forecast_horizon_minutes: Minutes after ``target_dt`` for the
            ground-truth image.

    Raises:
        FileNotFoundError: Propagated from ``find_nearest_browse_image_url``.
        requests.RequestException: On download failure, HTTP error status or
            timeout.
    """
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]
    # Upper bound (seconds) per image download so a stalled transfer cannot
    # hang the pipeline forever.
    download_timeout = 60

    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
    # A set removes duplicates in case the target coincides with an input time.
    all_times = sorted(set(input_times + [target_time]))

    images = {}
    total_fetches = len(all_times) * len(SDO_CHANNELS)
    fetches_done = 0
    yield f"Starting search for {total_fetches} data files..."
    for t in all_times:
        images[t] = {}
        for channel in SDO_CHANNELS:
            fetches_done += 1
            yield f"Finding [{fetches_done}/{total_fetches}]: Closest image for {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
            image_url = find_nearest_browse_image_url(channel, t)
            yield f"Downloading: {os.path.basename(image_url)}..."
            response = requests.get(image_url, timeout=download_timeout)
            response.raise_for_status()
            images[t][channel] = Image.open(BytesIO(response.content))
    yield "✅ All images found and downloaded. Starting preprocessing..."

    scalers_dict = APP_CACHE["scalers"]
    processed_tensors = {}
    for t, channel_images in images.items():
        channel_tensors = []
        for channel in SDO_CHANNELS:
            img = channel_images[channel]
            if img.mode != 'L':
                # Work on a single 8-bit grayscale plane.
                img = img.convert('L')
            img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
            norm_data = np.array(img_resized, dtype=np.float32)
            scaler = scalers_dict[channel]
            # The scaler is fed a column vector and the result is reshaped
            # back to the 2-D image shape.
            scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
            channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
        processed_tensors[t] = torch.stack(channel_tensors)
    yield "✅ Preprocessing complete."

    input_tensor_list = [processed_tensors[t] for t in input_times]
    input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
    target_image_map = images[target_time]
    last_input_image_map = images[input_times[-1]]
    yield (input_tensor, last_input_image_map, target_image_map)
def run_inference(input_tensor):
    """Run the cached model on one prepared input and return its prediction.

    Reads ``"model"``, ``"device"`` and the configured input time deltas from
    ``APP_CACHE``. The forward pass runs without gradient tracking and under
    bfloat16 autocast for the cached device type.

    Args:
        input_tensor: Tensor passed to the model under the ``"ts"`` key.

    Returns:
        The model output moved to the CPU and cast to ``torch.float32``.
    """
    net = APP_CACHE["model"]
    target_device = APP_CACHE["device"]
    deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]

    # One row of time offsets, i.e. a leading batch axis of size 1.
    delta_row = torch.tensor(deltas, dtype=torch.float32).unsqueeze(0)
    batch = {
        "ts": input_tensor.to(target_device),
        "time_delta_input": delta_row.to(target_device),
    }

    # autocast wants the bare device type ("cuda"), not "cuda:0".
    device_type = target_device.split(':')[0]
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        output = net(batch)

    output_on_cpu = output.cpu()
    return output_on_cpu.to(torch.float32)
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Build the (input, forecast, ground truth) image triple for one channel.

    The forecast slice ``prediction_tensor[0, <channel index>]`` is passed
    through ``inverse_transform_single_channel`` with the channel scaler's
    ``mean``, ``std``, ``epsilon`` and ``sl_scale_factor``. It is then clipped
    to ``[0, vmax]``, where ``vmax`` is the 99.5th percentile of the
    ground-truth image's pixel values, normalised by ``vmax`` and colour-mapped
    to an RGB PIL image.

    Args:
        last_input_map: Mapping of channel name to the last input image, or
            ``None`` when no forecast has been produced yet.
        prediction_tensor: Model output tensor, indexed as ``[0, channel]``.
        target_map: Mapping of channel name to the ground-truth image.
        channel_name: Entry of ``SDO_CHANNELS`` to display.

    Returns:
        tuple: ``(input_image, forecast_image, target_image)``, or
        ``(None, None, None)`` when ``last_input_map`` is ``None``.
    """
    if last_input_map is None:
        return None, None, None

    channel_index = SDO_CHANNELS.index(channel_name)

    # Scalers are stored per channel; their parameters are plain attributes.
    channel_scaler = APP_CACHE["scalers"][channel_name]
    scaler_params = {
        "mean": channel_scaler.mean,
        "std": channel_scaler.std,
        "epsilon": channel_scaler.epsilon,
        "sl_scale_factor": channel_scaler.sl_scale_factor,
    }
    forecast = inverse_transform_single_channel(
        prediction_tensor[0, channel_index].numpy(), **scaler_params
    )

    truth_image = target_map[channel_name]
    truth_pixels = np.array(truth_image)
    # The colour scale is anchored on the ground truth so that a few very
    # bright pixels do not wash out the picture.
    vmax = np.quantile(np.nan_to_num(truth_pixels), 0.995)

    if 'aia' in channel_name:
        cmap_name = "sdoaia" + channel_name.replace('aia', '')
    else:
        cmap_name = 'hmimag'
    # Unknown colormap names fall back to matplotlib's 'gray'.
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    input_image = last_input_map[channel_name]

    cleaned = np.clip(np.nan_to_num(forecast), 0, vmax)
    if vmax > 0:
        cleaned = cleaned / vmax
    # Drop the alpha channel and convert the 0..1 floats to 8-bit RGB.
    rgb = (cmap(cleaned)[:, :, :3] * 255).astype(np.uint8)
    forecast_image = Image.fromarray(rgb)

    return input_image, forecast_image, truth_image
def forecast_controller(date_str, hour, minute, forecast_horizon):
    """Drive one forecast run for the UI, yielding incremental updates.

    This is a generator used as a Gradio event handler. Every yielded item is
    a dict that maps UI components to new values or ``gr.update(...)`` calls.
    The sequence is: disable the inputs and show the log, stream status
    messages while the model is set up and data is fetched, publish the
    results, and finally re-enable the inputs.

    Any ``Exception`` raised along the way is logged and shown in the log box
    together with its traceback; it is not re-raised.

    Args:
        date_str: Date in ``YYYY-MM-DD`` form.
        hour: Hour of day; converted with ``int()``.
        minute: Minute of hour; converted with ``int()``.
        forecast_horizon: Minutes ahead of the selected time to forecast.

    Yields:
        dict: Component-to-update mappings.
    """
    input_controls = (run_button, date_input, hour_slider, minute_slider, horizon_slider)

    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        **{control: gr.update(interactive=False) for control in input_controls},
        results_group: gr.update(visible=False),
    }

    try:
        if not date_str:
            raise gr.Error("Please select a date.")

        for status in setup_and_load_model():
            yield {log_box: status}

        target_dt = datetime.datetime.fromisoformat(
            f"{date_str}T{int(hour):02d}:{int(minute):02d}:00"
        )

        # The pipeline yields status strings and, last, a result tuple.
        for item in fetch_and_process_sdo_data(target_dt, forecast_horizon):
            if isinstance(item, tuple):
                input_tensor, last_input_map, target_map = item
                break
            yield {log_box: item}
        else:
            # The generator ran out without ever producing the result tuple.
            raise gr.Error("Data processing pipeline finished unexpectedly.")

        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(
            last_input_map, prediction_tensor, target_map, "aia171"
        )

        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
            results_group: gr.update(visible=True),
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }
    except Exception as e:
        error_str = traceback.format_exc()
        logger.error(f"An error occurred: {e}\n{error_str}")
        yield {log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}"}

    # Re-enable the inputs on both the success and the handled-error path.
    # This is deliberately NOT inside a ``finally`` block: if the generator is
    # closed while suspended, a ``yield`` in ``finally`` would make Python
    # raise "RuntimeError: generator ignored GeneratorExit".
    yield {control: gr.update(interactive=True) for control in input_controls}
# Gradio UI definition.
# NOTE(review): the nesting below (which widgets sit inside which Row /
# Accordion / Group) was reconstructed from a whitespace-collapsed source
# line -- confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session results of the last forecast, reused when the user switches
    # the displayed channel.
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

    # The Markdown text is kept flush-left so the string content is unchanged.
    gr.Markdown(
        """
# ☀️ Surya: Live Forecast Demo ☀️
### A Foundation Model for Solar Dynamics
This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.
"""
    )

    # Take one timestamp (three hours ago, UTC) and derive every default from
    # it, so date, hour and minute cannot disagree across a clock boundary.
    default_dt = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)

    with gr.Accordion("Step 1: Configure Forecast", open=True):
        with gr.Row():
            date_input = gr.Textbox(
                label="Date (YYYY-MM-DD)",
                value=default_dt.strftime("%Y-%m-%d"),
            )
            hour_slider = gr.Slider(
                label="Hour (UTC)", minimum=0, maximum=23, step=1, value=default_dt.hour
            )
            minute_slider = gr.Slider(
                label="Minute (UTC)", minimum=0, maximum=59, step=1, value=default_dt.minute
            )
        horizon_slider = gr.Slider(
            label="Forecast Horizon (minutes ahead)",
            minimum=12,
            maximum=120,
            step=12,
            value=12,
        )
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    with gr.Accordion("Step 2: View Log", open=False):
        log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)

    with gr.Group(visible=False) as results_group:
        gr.Markdown("### Step 3: Explore Results")
        channel_selector = gr.Dropdown(
            choices=SDO_CHANNELS,
            value="aia171",
            label="🛰️ Select SDO Channel to Visualize",
        )
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    # Run a full forecast; forecast_controller streams updates to these outputs.
    run_button.click(
        fn=forecast_controller,
        inputs=[date_input, hour_slider, minute_slider, horizon_slider],
        outputs=[
            log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider,
            results_group, state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    )

    # Re-render the three images from the stored results on channel change.
    channel_selector.change(
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )

if __name__ == "__main__":
    demo.launch(debug=True)