# surya-demo / app.py
# Hugging Face Space by broadfield-dev — commit 4ac0d86 "Update app.py" (verified), 13.6 kB.
import gradio as gr
import torch
from huggingface_hub import snapshot_download
import yaml
import numpy as np
from PIL import Image
import requests
import os
import warnings
import logging
import datetime
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import traceback
from io import BytesIO
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers
from surya.datasets.helio import inverse_transform_single_channel
# Silence sunpy's UserWarnings and all FutureWarnings so they do not clutter the app log.
warnings.filterwarnings("ignore", category=UserWarning, module="sunpy")
warnings.filterwarnings("ignore", category=FutureWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide cache. setup_and_load_model() fills it with the keys
# "config", "scalers", "device" and "model"; other functions read from it.
APP_CACHE = {}

# SDO channel name -> code embedded in the browse-image filename.
# Insertion order matters: SDO_CHANNELS (below) is derived from it and is used
# as the channel index order elsewhere in this file.
CHANNEL_TO_URL_CODE = {
    "aia94": "0094",
    "aia131": "0131",
    "aia171": "0171",
    "aia193": "0193",
    "aia211": "0211",
    "aia304": "0304",
    "aia335": "0335",
    "aia1600": "1600",
    "hmi_m": "HMIBC",
    # The three HMI vector components share a single browse-image code.
    "hmi_bx": "HMIB",
    "hmi_by": "HMIB",
    "hmi_bz": "HMIB",
    "hmi_v": "HMID",
}

# Channel names in the order defined above.
SDO_CHANNELS = [*CHANNEL_TO_URL_CODE]
def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress messages.

    This is a generator; every yielded value is a human-readable status string.
    On first use it downloads ``config.yaml``, ``scalers.yaml`` and the
    ``surya.366m.v1.pt`` checkpoint from the ``nasa-ibm-ai4science/Surya-1.0``
    Hub repo into ``data/Surya-1.0``. It then stores ``"config"``,
    ``"scalers"``, ``"device"`` and ``"model"`` in APP_CACHE.

    If ``"model"`` is already cached, it yields a single message and returns
    without doing any work. ``"model"`` is written last, so a run that fails
    part-way is retried in full on the next call.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    model_dir = "data/Surya-1.0"

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir=model_dir,
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
    )

    yield "Loading configuration and data scalers..."
    with open(f"{model_dir}/config.yaml", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Use a context manager so the handle is closed deterministically rather
    # than whenever the garbage collector gets to it.
    with open(f"{model_dir}/scalers.yaml", "r", encoding="utf-8") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"], patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]), embed_dim=model_config["embed_dim"],
        time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
        depth=model_config["depth"], n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"], mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"], dtype=torch.bfloat16,
        window_size=model_config["window_size"], dp_rank=model_config["dp_rank"],
        # NOTE(review): use_latitude_in_learned_flow reuses the learned_flow flag;
        # presumably this mirrors the upstream Surya setup. Confirm.
        learned_flow=model_config["learned_flow"], use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False, checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"], ensemble=model_config["ensemble"], finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device
    yield f"Loading model weights to {device}..."
    weights = torch.load(f"{model_dir}/surya.366m.v1.pt", map_location=torch.device(device))
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    # Cache the model only once everything above succeeded.
    APP_CACHE["model"] = model
    yield "✅ Model setup complete."
def fetch_browse_image(channel, target_dt, max_retries=15, timeout=30):
    """Download an SDO browse JPEG for ``channel`` at or shortly before ``target_dt``.

    Tries ``target_dt``, then one minute earlier, and so on, for ``max_retries``
    attempts in total (minute offsets 0 .. max_retries - 1). It returns the
    first image the server answers with HTTP 200.

    Args:
        channel: a key of CHANNEL_TO_URL_CODE (e.g. ``"aia171"``).
        target_dt: datetime to start searching from. Its seconds field is used
            verbatim in the filename.
        max_retries: number of one-minute steps to try.
        timeout: seconds passed to ``requests.get`` so an unresponsive server
            cannot hang the app indefinitely.

    Returns:
        A PIL image opened from the downloaded bytes.

    Raises:
        FileNotFoundError: if no attempt returned HTTP 200.
        requests.RequestException: on network errors, including timeouts.
    """
    url_code = CHANNEL_TO_URL_CODE[channel]
    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
    for i in range(max_retries):
        dt_to_try = target_dt - datetime.timedelta(minutes=i)
        date_str = dt_to_try.strftime("%Y/%m/%d")
        # NOTE(review): only exact-second filename matches are found. Confirm
        # that archive timestamps line up with the seconds in dt_to_try.
        img_str = dt_to_try.strftime(f"%Y%m%d_%H%M%S_4096_{url_code}.jpg")
        url = f"{base_url}/{date_str}/{img_str}"
        response = requests.get(url, timeout=timeout)
        if response.status_code == 200:
            logger.info(f"Successfully found image for {channel} at {dt_to_try}")
            return Image.open(BytesIO(response.content))
    raise FileNotFoundError(f"Could not find any recent image for {channel} within {max_retries} minutes of {target_dt}.")
def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
    """Fetch browse images for every needed time/channel and build the model input.

    Generator protocol: it yields progress strings, then, as its final item, a tuple
    ``(input_tensor, last_input_image_map, target_image_map)``:

    - ``input_tensor`` has shape (1, n_channels, n_input_times, img_size, img_size).
    - The two maps are ``{channel: PIL image}`` for the last input time and the
      forecast target time respectively.

    Input times are ``target_dt`` offset by each entry of
    ``config["data"]["time_delta_input_minutes"]``. The target time is
    ``target_dt + forecast_horizon_minutes``.

    Raises:
        FileNotFoundError: propagated from fetch_browse_image when an image is missing.
    """
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]
    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
    all_times = sorted(set(input_times) | {target_time})

    images = {}
    total_fetches = len(all_times) * len(SDO_CHANNELS)
    fetches_done = 0
    yield f"Starting search for {total_fetches} data files..."
    for t in all_times:
        images[t] = {}
        # Several channels (hmi_bx/hmi_by/hmi_bz) share one URL code and would
        # otherwise download the identical file repeatedly. Reuse it within this
        # time step. Reuse is safe because preprocessing only calls
        # non-mutating convert()/resize().
        fetched_by_code = {}
        for channel in SDO_CHANNELS:
            fetches_done += 1
            yield f"Searching [{fetches_done}/{total_fetches}]: {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
            url_code = CHANNEL_TO_URL_CODE[channel]
            if url_code not in fetched_by_code:
                fetched_by_code[url_code] = fetch_browse_image(channel, t)
            images[t][channel] = fetched_by_code[url_code]
    yield "✅ All images found. Starting preprocessing..."

    scaler = APP_CACHE["scalers"]
    processed_tensors = {}
    for t, channel_images in images.items():
        channel_tensors = []
        # NOTE(review): c_idx follows SDO_CHANNELS order. Presumably it must
        # match the scalers' / config["data"]["sdo_channels"] channel order. Confirm.
        for c_idx, channel in enumerate(SDO_CHANNELS):
            img = channel_images[channel]
            if img.mode != 'L':
                img = img.convert('L')
            img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
            # 8-bit greyscale JPEG brightness, not calibrated physical units.
            pixels = np.array(img_resized, dtype=np.float32)
            scaled = scaler.transform(pixels.reshape(-1, 1), c_idx=c_idx).reshape(pixels.shape)
            channel_tensors.append(torch.from_numpy(scaled.astype(np.float32)))
        processed_tensors[t] = torch.stack(channel_tensors)
    yield "✅ Preprocessing complete."

    # Stack input times along dim 1: (C, H, W) per time -> (1, C, T, H, W).
    input_tensor = torch.stack([processed_tensors[t] for t in input_times], dim=1).unsqueeze(0)
    target_image_map = images[target_time]
    last_input_image_map = images[input_times[-1]]
    yield (input_tensor, last_input_image_map, target_image_map)
def run_inference(input_tensor):
    """Run the cached Surya model on one preprocessed input batch.

    Args:
        input_tensor: model input as built by fetch_and_process_sdo_data,
            i.e. shape (1, n_channels, n_input_times, H, W).

    Returns:
        The model prediction as a float32 tensor on the CPU.
    """
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    # Shape (1, n_input_times): one batch row of input time offsets.
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
    with torch.no_grad(), torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
        prediction = model(input_batch)
    # Under bfloat16 autocast the output may itself be bfloat16. NumPy has no
    # bfloat16 dtype, so upcast to keep downstream `.numpy()` calls working.
    return prediction.cpu().float()
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Render the last input, the forecast and the ground truth for one channel.

    Args:
        last_input_map: ``{channel: PIL image}`` for the last input time.
        prediction_tensor: model output; indexed as ``[0, channel_index]`` to get
            the (expected 2-D) slice for ``channel_name``.
        target_map: ``{channel: PIL image}`` for the forecast target time.
        channel_name: an entry of SDO_CHANNELS.

    Returns:
        ``(input_image, forecast_image, target_image)``. All three are None when
        there is no forecast to show yet or the channel is not recognised
        (e.g. the dropdown was cleared).
    """
    if last_input_map is None or prediction_tensor is None or target_map is None:
        return None, None, None
    if channel_name not in SDO_CHANNELS:
        return None, None, None

    c_idx = SDO_CHANNELS.index(channel_name)
    scaler = APP_CACHE["scalers"]
    all_means, all_stds, all_epsilons, all_sl_scale_factors = scaler.get_params()
    mean, std, epsilon, sl_scale_factor = all_means[c_idx], all_stds[c_idx], all_epsilons[c_idx], all_sl_scale_factors[c_idx]
    # Upcast first: a bfloat16 tensor (from autocast inference) cannot be
    # converted with .numpy(). For a float32 tensor .float() is a no-op.
    pred_slice = inverse_transform_single_channel(
        prediction_tensor[0, c_idx].float().numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
    )

    # The colour scale is anchored to the ground truth's 99.5th percentile so the
    # forecast is shown on the same brightness range.
    target_img_data = np.array(target_map[channel_name])
    vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    def to_pil(data):
        """Clip ``data`` to [0, vmax], normalise, colour-map and return an RGB PIL image."""
        data_clipped = np.clip(np.nan_to_num(data), 0, vmax)
        data_norm = data_clipped / vmax if vmax > 0 else data_clipped
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        return Image.fromarray(colored)

    return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]
def forecast_controller(date_str, hour, minute, forecast_horizon):
    """Gradio generator handler: run one forecast end to end, streaming UI updates.

    Each yielded value is a dict mapping Gradio components (module-level globals
    such as log_box, run_button, results_group) to new values or ``gr.update``
    objects. The input controls are disabled for the run and re-enabled at the
    end, whether it succeeded or failed. Errors are reported in the log box
    rather than raised.

    Args:
        date_str: date as ``"YYYY-MM-DD"``.
        hour: UTC hour of day (converted with ``int``).
        minute: UTC minute (converted with ``int``).
        forecast_horizon: minutes after the chosen time to forecast.
    """
    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        run_button: gr.update(interactive=False),
        date_input: gr.update(interactive=False),
        hour_slider: gr.update(interactive=False),
        minute_slider: gr.update(interactive=False),
        horizon_slider: gr.update(interactive=False),
        results_group: gr.update(visible=False),
    }
    try:
        if not date_str:
            raise gr.Error("Please select a date.")
        for status in setup_and_load_model():
            yield {log_box: status}

        target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
        # The pipeline yields status strings and finally a result tuple.
        for status in fetch_and_process_sdo_data(target_dt, forecast_horizon):
            if isinstance(status, tuple):
                input_tensor, last_input_map, target_map = status
                break
            yield {log_box: status}
        else:
            raise gr.Error("Data processing pipeline finished unexpectedly.")

        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
            results_group: gr.update(visible=True),
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }
    except Exception as e:
        error_str = traceback.format_exc()
        logger.error(f"An error occurred: {e}\n{error_str}")
        yield {log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}"}

    # Deliberately not in a `finally:` clause. If Gradio closes this generator
    # mid-run (e.g. the client disconnects), a `yield` inside `finally` would
    # raise RuntimeError("generator ignored GeneratorExit").
    yield {
        run_button: gr.update(interactive=True),
        date_input: gr.update(interactive=True),
        hour_slider: gr.update(interactive=True),
        minute_slider: gr.update(interactive=True),
        horizon_slider: gr.update(interactive=True),
    }
# Default the pickers to "now minus 3 hours" in UTC. This is computed once so the
# date, hour and minute defaults describe the same instant; separate now() calls
# could straddle a minute or midnight boundary.
_default_dt = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session storage of the latest forecast, so changing the channel
    # re-renders without re-running the model.
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()
    # The blank lines inside the <div> let the renderer parse the inner text as
    # Markdown instead of swallowing it into a raw HTML block.
    gr.Markdown(
        """
<div align='center'>

# ☀️ Surya: Live Forecast Demo ☀️
### A Foundation Model for Solar Dynamics
This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
<br>
<p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>

</div>
"""
    )

    with gr.Accordion("Step 1: Configure Forecast", open=True):
        with gr.Row():
            date_input = gr.Textbox(
                label="Date (YYYY-MM-DD)",
                value=_default_dt.strftime("%Y-%m-%d"),
            )
            hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=_default_dt.hour)
            minute_slider = gr.Slider(label="Minute (UTC)", minimum=0, maximum=59, step=1, value=_default_dt.minute)
        horizon_slider = gr.Slider(
            label="Forecast Horizon (minutes ahead)",
            minimum=12, maximum=120, step=12, value=12,
        )
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    with gr.Accordion("Step 2: View Log", open=False):
        log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)

    with gr.Group(visible=False) as results_group:
        gr.Markdown("### Step 3: Explore Results")
        channel_selector = gr.Dropdown(
            choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
        )
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    run_button.click(
        fn=forecast_controller,
        inputs=[date_input, hour_slider, minute_slider, horizon_slider],
        outputs=[
            log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider, results_group,
            state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    ).then(
        # forecast_controller always renders "aia171". Re-render with the
        # channel the dropdown currently shows so the label and images agree.
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )
    channel_selector.change(
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )

if __name__ == "__main__":
    demo.launch(debug=True)