# Save this file as app.py in the root of the cloned Surya repository.
import datetime
import logging
import os
import warnings

import astropy.units as u
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import sunpy.map
import sunpy.net.attrs as a
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from astropy.wcs import WCS
from huggingface_hub import snapshot_download
from PIL import Image
from reproject import reproject_interp  # Correct import statement
from sunpy.net import Fido

# --- Use the official Surya modules ---
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers, inverse_transform_single_channel
# --- Configuration ---
# Silence sunpy UserWarnings and all FutureWarnings so the app log stays readable.
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global cache for model, config, etc.
# Keys written elsewhere in this file: "config", "scalers", "device", "model".
APP_CACHE = {}

# Maps each model channel name to the (selector attr, cadence attr) pair used to
# query it through Fido. Insertion order matters: SDO_CHANNELS below is derived
# from it and is used to index channels of the model output.
#
# a.Wavelength takes astropy Quantities (wavemin, optional wavemax); a single
# Quantity selects exactly that wavelength.
SDO_CHANNELS_MAP = {
    "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
    "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
    "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
    "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
    "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
    "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
    "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
    "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
    # NOTE(review): "hmi_m" is queried as "intensity" here — confirm this is the
    # observable the model was trained on for that channel.
    "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
    "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
    "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
}

# Channel names in model order.
SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
# --- 1. Model Loading and Setup ---
def setup_and_load_model(progress=gr.Progress()):
    """Download the Surya files if needed and populate ``APP_CACHE``.

    Does nothing when a model is already cached. Otherwise it downloads the
    config, scalers and checkpoint from the Hugging Face Hub into
    ``data/Surya-1.0``, builds a ``HelioSpectFormer`` from the config, loads
    the weights and stores ``"config"``, ``"scalers"``, ``"device"`` and
    ``"model"`` in ``APP_CACHE``.

    Args:
        progress: Gradio progress callback, called as ``progress(fraction, desc=...)``.
    """
    if "model" in APP_CACHE:
        return

    model_dir = "data/Surya-1.0"

    progress(0.1, desc="Downloading model files (first run only)...")
    snapshot_download(repo_id="nasa-ibm-ai4science/Surya-1.0", local_dir=model_dir,
                      allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"])

    progress(0.5, desc="Loading configuration and scalers...")
    with open(f"{model_dir}/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Context manager so the scalers file handle is closed deterministically.
    with open(f"{model_dir}/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    progress(0.7, desc="Initializing and loading model...")
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        # NOTE(review): reuses the "learned_flow" setting for this flag as well —
        # confirm that is intended rather than a dedicated config key.
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    # The checkpoint is fed straight into load_state_dict, so restrict
    # unpickling to tensors/containers instead of arbitrary Python objects.
    weights = torch.load(f"{model_dir}/surya.366m.v1.pt",
                         map_location=torch.device(device), weights_only=True)
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()

    # Stored last: "model" doubles as the "setup finished" marker checked above.
    APP_CACHE["model"] = model
    logger.info("Model setup complete.")
# --- 2. Live Data Fetching and Preprocessing ---
def _first_if_sequence(value):
    """Return ``value[0]`` for a list/tuple, otherwise ``value`` unchanged."""
    if isinstance(value, (list, tuple)):
        return value[0]
    return value


def fetch_and_process_sdo_data(target_dt, progress):
    """Download SDO maps around ``target_dt`` and build the model input tensor.

    For every input time and the target time (offsets come from the cached
    config), one file per channel in ``SDO_CHANNELS_MAP`` is fetched through
    Fido into ``./data/sdo_cache``. Each map is reprojected onto a common
    helioprojective grid, divided by its exposure time, scaled with the
    channel's scaler and stacked.

    Args:
        target_dt: ``datetime.datetime`` reference time; config offsets (in
            minutes) are added to it.
        progress: Gradio progress callback, called as ``progress(fraction, desc=...)``.

    Returns:
        Tuple ``(input_tensor, last_input_map, target_map)`` where
        ``input_tensor`` is the stacked input with a leading batch axis of
        size 1, and the two maps are ``{channel: sunpy map}`` dicts for the
        last input time and the target time.

    Raises:
        ValueError: if a search returns nothing or a download yields no file.
    """
    config = APP_CACHE["config"]
    # These config entries may be a scalar or a one-element list; accept both.
    img_size = _first_if_sequence(config["model"]["img_size"])
    input_deltas = config["data"]["time_delta_input_minutes"]
    target_delta = _first_if_sequence(config["data"]["time_delta_target_minutes"])

    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=target_delta)
    all_times = sorted(set(input_times + [target_time]))

    # ---- Download: one map per (time, channel) ----
    data_maps = {}
    total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
    downloads_done = 0
    search_window = datetime.timedelta(minutes=10)
    for t in all_times:
        data_maps[t] = {}
        for channel, (selector, sample) in SDO_CHANNELS_MAP.items():
            progress(downloads_done / total_downloads,
                     desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")

            # Placeholder channels reuse the hmi_bx map instead of downloading.
            if channel in ("hmi_by", "hmi_bz"):
                bx_map = data_maps[t].get("hmi_bx")
                if bx_map is not None:
                    data_maps[t][channel] = bx_map
                # Count it so the progress fraction can reach completion.
                downloads_done += 1
                continue

            # AIA and HMI queries differ only in the instrument attr.
            instrument = a.Instrument.aia if "aia" in channel else a.Instrument.hmi
            time_attr = a.Time(t - search_window, t + search_window)
            query = Fido.search(time_attr, selector, sample, instrument)
            if not query:
                raise ValueError(f"No data found for {channel} at {t}")

            # NOTE(review): this takes the first record of the +/-10 min window,
            # which is not necessarily the one closest to `t` — confirm.
            files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
            if not files:
                raise ValueError(f"Download failed for {channel} at {t}")
            data_maps[t][channel] = sunpy.map.Map(files[0])
            downloads_done += 1

    # ---- Common output grid ----
    # Plain floats plus an explicit cunit: assigning a Quantity to cdelt/crval
    # drops the unit, and without cunit the values are read as degrees.
    output_wcs = WCS(naxis=2)
    output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
    output_wcs.wcs.cunit = ['arcsec', 'arcsec']
    output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
    output_wcs.wcs.cdelt = [-1.2, 1.2]
    output_wcs.wcs.crval = [0.0, 0.0]

    # ---- Reproject, normalise and scale every channel ----
    processed_tensors = {}
    total_processing = len(all_times) * len(SDO_CHANNELS)
    processing_done = 0
    for t, channel_maps in data_maps.items():
        channel_tensors = []
        for channel in SDO_CHANNELS:
            progress(processing_done / total_processing,
                     desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
            smap = channel_maps[channel]
            reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))

            # Fall back to 1.0 when the header has no usable exposure time.
            exp_time = smap.meta.get('exptime', 1.0)
            if exp_time is None or exp_time <= 0:
                exp_time = 1.0
            norm_data = reprojected_data / exp_time

            scaled_data = APP_CACHE["scalers"][channel].transform(norm_data)
            channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
            processing_done += 1
        processed_tensors[t] = torch.stack(channel_tensors)

    # Stack the input times on a new axis 1, then add the batch axis in front.
    input_tensor = torch.stack([processed_tensors[t] for t in input_times], dim=1).unsqueeze(0)
    target_map = data_maps[target_time]
    last_input_map = data_maps[input_times[-1]]
    return input_tensor, last_input_map, target_map
# --- 3. Inference and Visualization ---
def run_inference(input_tensor):
    """Run the cached model on ``input_tensor`` and return the prediction on CPU.

    The model and device are read from ``APP_CACHE``; the configured input
    time offsets are passed along as ``"time_delta_input"``. The forward pass
    runs without gradients under bfloat16 autocast.

    Args:
        input_tensor: batched model input as produced by
            ``fetch_and_process_sdo_data``.

    Returns:
        The prediction as a float32 CPU tensor.
    """
    logger.info("Running model inference...")
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]

    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}

    # autocast wants the bare device type ("cuda"), not e.g. "cuda:0".
    device_type = device.split(':')[0]
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        prediction = model(input_batch)

    logger.info("Inference complete.")
    # Under bfloat16 autocast the output may be bfloat16, which NumPy cannot
    # represent; callers convert the result with .numpy(), so return float32.
    return prediction.float().cpu()
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Render the last input, the forecast and the ground truth for one channel.

    Args:
        last_input_map: ``{channel: sunpy map}`` for the last input time, or
            ``None`` when no forecast has been generated yet.
        prediction_tensor: model output; ``prediction_tensor[0, c]`` is used as
            the 2-D image of channel index ``c``.
        target_map: ``{channel: sunpy map}`` for the target time.
        channel_name: one of ``SDO_CHANNELS``.

    Returns:
        Tuple of three PIL images ``(input, prediction, target)``, or
        ``(None, None, None)`` when ``last_input_map`` is ``None``.
    """
    if last_input_map is None:
        return None, None, None

    c_idx = SDO_CHANNELS.index(channel_name)

    # Process Prediction
    # NOTE(review): the first channel's scaler is asked for parameters that are
    # then indexed by channel — confirm get_params() returns per-channel arrays.
    means, stds, epsilons, sl_scale_factors = APP_CACHE["scalers"][SDO_CHANNELS[0]].get_params()
    # Cast to float32 first: NumPy has no bfloat16, so .numpy() would raise on
    # a bfloat16 prediction. A no-op when the tensor is already float32.
    pred_array = prediction_tensor[0, c_idx].detach().float().cpu().numpy()
    pred_slice = inverse_transform_single_channel(
        pred_array,
        mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx],
        sl_scale_factor=sl_scale_factors[c_idx],
    )

    # All three panels share one colour scale, taken from the ground truth.
    vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    def to_pil(data, flip=False):
        """Clip ``data`` to [0, vmax], colour-map it and return an RGB PIL image."""
        data_clipped = np.clip(np.nan_to_num(data), 0, vmax)
        data_norm = data_clipped / vmax if vmax > 0 else data_clipped
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)  # drop alpha
        img = Image.fromarray(colored)
        return img.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if flip else img

    # Only the prediction panel is flipped vertically.
    return (to_pil(last_input_map[channel_name].data),
            to_pil(pred_slice, flip=True),
            to_pil(target_map[channel_name].data))
# --- 4. Gradio UI and Controllers ---
def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
    """Run the full forecast pipeline for the time entered in the UI.

    Loads the model if needed, fetches and preprocesses the data, runs
    inference and renders the default channel (``"aia171"``).

    Args:
        dt_str: ISO-format date/time string from the textbox.
        progress: Gradio progress callback.

    Returns:
        Tuple ``(last_input_map, prediction_tensor, target_map, img_in,
        img_pred, img_target, status, results_group_update)`` matching the
        outputs wired to the run button.

    Raises:
        gr.Error: when no date is given, or when any pipeline step fails.
    """
    # Validate outside the try block so this message reaches the user as-is
    # instead of being re-wrapped as a generic forecast failure.
    if not dt_str or not dt_str.strip():
        raise gr.Error("Please select a date and time.")

    try:
        progress(0, desc="Initializing...")
        setup_and_load_model(progress)

        target_dt = datetime.datetime.fromisoformat(dt_str.strip())
        logger.info(f"Starting forecast for target time: {target_dt}")

        input_tensor, last_input_map, target_map = fetch_and_process_sdo_data(target_dt, progress)
        prediction_tensor = run_inference(input_tensor)
        img_in, img_pred, img_target = generate_visualization(
            last_input_map, prediction_tensor, target_map, "aia171")

        status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
        logger.info(status)
        return (last_input_map, prediction_tensor, target_map,
                img_in, img_pred, img_target, status, gr.update(visible=True))
    except Exception as e:
        # Top-level UI boundary: log the traceback, surface a readable error.
        logger.error(f"An error occurred: {e}", exc_info=True)
        raise gr.Error(f"Failed to generate forecast. Error: {e}") from e
def update_visualization_controller(last_input_map, prediction_tensor, target_map, channel_name):
    """Re-render the three panels when a different channel is selected.

    Returns ``(None, None, None)`` when no forecast is stored yet (that is,
    ``last_input_map`` is ``None``); otherwise returns the result of
    ``generate_visualization`` for ``channel_name``.
    """
    forecast_available = last_input_map is not None
    if forecast_available:
        return generate_visualization(last_input_map, prediction_tensor, target_map, channel_name)
    return None, None, None
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # State holding the last forecast, so switching channels only re-renders
    # images instead of re-running the pipeline.
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

    gr.Markdown(
        """
<div align='center'>
# ☀️ Surya: Live Forecast Demo ☀️
### Generate a real forecast for any recent date using NASA's Heliophysics Model.
**Instructions:**
1. Pick a date and time (at least 1 hour in the past).
2. Click 'Generate Forecast'. **This will be slow (5-15 minutes) as it downloads live data.**
3. Once complete, select different channels to explore the multi-spectrum forecast.
</div>
"""
    )

    with gr.Row():
        # Default to three hours ago in UTC rather than server-local time: the
        # value is passed on as a naive datetime, so a local time ahead of UTC
        # could point at data that does not exist yet.
        default_start = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)
        datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
                                    value=default_start.strftime("%Y-%m-%d %H:%M:%S"))
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    # Hidden until the first forecast completes (forecast_controller makes it visible).
    with gr.Group(visible=False) as results_group:
        status_box = gr.Textbox(label="Status", interactive=False)
        channel_selector = gr.Dropdown(choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel")
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    # Output order must match the tuple returned by forecast_controller.
    run_button.click(
        fn=forecast_controller,
        inputs=[datetime_input],
        outputs=[state_last_input, state_prediction, state_target,
                 input_display, prediction_display, target_display, status_box, results_group]
    )

    # Changing the channel re-renders from the stored state.
    channel_selector.change(
        fn=update_visualization_controller,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display]
    )
if __name__ == "__main__":
    # Create the directory the downloaded SDO files are written to, then
    # start the Gradio app in debug mode.
    sdo_cache_dir = "./data/sdo_cache"
    os.makedirs(sdo_cache_dir, exist_ok=True)
    demo.launch(debug=True)