# Surya solar-forecast demo (Gradio app).
#
# Save this file in the root of the cloned Surya repository so that the
# `surya` package imported below can be found.
import datetime
import logging
import os
import traceback
import warnings

import astropy.units as u
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import sunpy.map
import sunpy.net.attrs as a
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from astropy.wcs import WCS
from huggingface_hub import snapshot_download
from PIL import Image
from reproject import reproject_interp
from sunpy.net import Fido

# --- Use the official Surya modules ---
from surya.datasets.helio import inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers

# --- Configuration ---
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide cache, filled lazily by setup_and_load_model() with the keys
# "config", "scalers", "device" and "model".
APP_CACHE = {}

# Channel name -> (Fido search attribute, query cadence).
# The insertion ORDER of this dict defines each channel's index: it is used
# as the scaler index during preprocessing and to index the model output
# during visualisation, so do not reorder entries.
# NOTE(review): presumably this order must also match
# config["data"]["sdo_channels"] from the model config — TODO confirm.
SDO_CHANNELS_MAP = {
    "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
    "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
    "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
    "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
    "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
    "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
    "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
    "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
    # NOTE(review): "hmi_m" is queried as continuum *intensity*; if the model
    # expects a line-of-sight magnetogram for this channel, this should be
    # the LOS magnetic field instead — confirm against the model config.
    "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
    # The three "vector" components all query the LOS field; the fetch step
    # reuses the hmi_bx map for hmi_by/hmi_bz, so these are stand-ins rather
    # than true vector-field components.
    "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
    "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
}
SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
# --- 1. Model Loading and Setup ---
def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress strings.

    This is a generator: every yielded value is a human-readable status
    message for the UI. On completion APP_CACHE holds "config", "scalers",
    "device" and "model". If "model" is already cached, a single message is
    yielded and nothing else happens. Because "model" is stored last, a run
    that fails part-way is redone from the start on the next call.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    model_dir = "data/Surya-1.0"

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir=model_dir,
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
    )

    yield "Loading configuration and data scalers..."
    with open(f"{model_dir}/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Use a context manager so the handle is closed deterministically; the
    # previous bare open() left it to the garbage collector.
    with open(f"{model_dir}/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    model_config = config["model"]
    data_config = config["data"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(data_config["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        # time_dim is the number of input time steps.
        time_embedding={"type": "linear", "time_dim": len(data_config["time_delta_input_minutes"])},
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        # NOTE(review): `learned_flow` is also passed for
        # `use_latitude_in_learned_flow`; kept as-is — confirm against the
        # upstream Surya example before changing it.
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,  # weights come from the checkpoint loaded below
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    yield f"Loading model weights to {device}..."
    weights = torch.load(f"{model_dir}/surya.366m.v1.pt", map_location=torch.device(device))
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    APP_CACHE["model"] = model

    yield "✅ Model setup complete."
# --- 2. Live Data Fetching and Preprocessing (as a generator) ---
def _build_output_wcs(smap, img_size):
    """Return a square ``img_size`` x ``img_size`` helioprojective WCS for ``smap``.

    The grid is centred on (0", 0") in ``smap``'s own coordinate frame (so it
    carries that map's observer and observation time), uses a plate scale of
    1.2 arcsec/pixel and has no rotation.
    """
    from astropy.coordinates import SkyCoord

    # A bare WCS(naxis=2) has no observer/obstime, and numeric CDELT values
    # without CUNIT are interpreted in degrees, not arcsec. Building the header
    # from the map's frame with sunpy avoids both problems.
    # NOTE(review): 1.2"/px * img_size sets the field of view; confirm it
    # matches the plate scale the model was trained on.
    center = SkyCoord(0 * u.arcsec, 0 * u.arcsec, frame=smap.coordinate_frame)
    header = sunpy.map.make_fitswcs_header(
        (img_size, img_size),
        center,
        scale=u.Quantity([1.2, 1.2], u.arcsec / u.pix),
    )
    return WCS(header)


def _preprocess_map(smap, img_size, scaler, c_idx):
    """Reproject, exposure-normalise and scale one map into a float32 (H, W) tensor."""
    output_wcs = _build_output_wcs(smap, img_size)
    reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
    exp_time = smap.meta.get('exptime', 1.0)
    if exp_time is None or exp_time <= 0:
        exp_time = 1.0
    # Output pixels that fall outside the source image come back as NaN; fill
    # them so NaNs are not fed through the scaler into the model.
    norm_data = np.nan_to_num(reprojected_data / exp_time, nan=0.0)
    scaled_data = scaler.transform(norm_data, c_idx=c_idx)
    return torch.from_numpy(scaled_data.astype(np.float32))


def fetch_and_process_sdo_data(target_dt):
    """Download SDO maps around ``target_dt`` and build the model input tensor.

    Generator. It yields human-readable status strings while working and, as
    its final item, a 3-tuple ``(input_tensor, last_input_maps, target_maps)``:

    - ``input_tensor``: float32 tensor of shape (1, C, T, H, W), where
      C = len(SDO_CHANNELS), T = number of input times and H = W = img_size.
    - ``last_input_maps`` / ``target_maps``: dicts mapping channel name to the
      raw (not reprojected) sunpy Map, for the last configured input time and
      for the target time respectively.

    Input times are ``target_dt`` shifted by each entry of
    config["data"]["time_delta_input_minutes"]; the target time is shifted by
    config["data"]["time_delta_target_minutes"].

    Raises:
        ValueError: if no data is found, or a download returns no file, for
            some channel/time.
    """
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]
    # A single offset in minutes (not a list).
    target_delta = config["data"]["time_delta_target_minutes"]

    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=target_delta)
    all_times = sorted(set(input_times + [target_time]))

    data_maps = {}
    total_downloads = len(all_times) * len(SDO_CHANNELS)
    downloads_done = 0
    yield f"Starting download of {total_downloads} data files..."
    for t in all_times:
        maps_at_t = data_maps[t] = {}
        for channel, (physobs, sample) in SDO_CHANNELS_MAP.items():
            downloads_done += 1
            yield f"Downloading [{downloads_done}/{total_downloads}]: {channel} for {t.strftime('%Y-%m-%d %H:%M')}..."
            # hmi_by / hmi_bz are placeholders: reuse the hmi_bx map if present.
            if channel in ("hmi_by", "hmi_bz") and maps_at_t.get("hmi_bx") is not None:
                maps_at_t[channel] = maps_at_t["hmi_bx"]
                continue

            time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
            instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
            query = Fido.search(time_attr, instrument, physobs, sample)
            # len(UnifiedResponse) counts provider tables, not records, so an
            # empty search can still be non-empty; check the record count.
            if query.file_num == 0:
                raise ValueError(f"No data found for {channel} at {t}")
            files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
            if not files:
                raise ValueError(f"Download failed for {channel} at {t}")
            maps_at_t[channel] = sunpy.map.Map(files[0])

    yield "✅ All files downloaded. Starting preprocessing..."
    scaler = APP_CACHE["scalers"]
    processed_tensors = {}
    # Only input times feed the model; the target maps are returned raw, so
    # reprojecting them would be wasted work. dict.fromkeys de-duplicates
    # while keeping order.
    for t in dict.fromkeys(input_times):
        channel_maps = data_maps[t]
        processed_tensors[t] = torch.stack([
            _preprocess_map(channel_maps[channel], img_size, scaler, c_idx)
            for c_idx, channel in enumerate(SDO_CHANNELS)
        ])
    yield "✅ Preprocessing complete."

    # Each entry is (C, H, W); stacking on dim=1 gives (C, T, H, W), then a
    # leading batch dimension of 1.
    input_tensor = torch.stack([processed_tensors[t] for t in input_times], dim=1).unsqueeze(0)
    yield (input_tensor, data_maps[input_times[-1]], data_maps[target_time])
# --- 3. Inference and Visualization ---
def run_inference(input_tensor):
    """Run the cached Surya model on one preprocessed input batch.

    Args:
        input_tensor: Model input as built by ``fetch_and_process_sdo_data``
            (assumed (1, C, T, H, W) — TODO confirm for any other caller).

    Returns:
        The model prediction, on the CPU, as a float32 tensor.
    """
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    # Shape (1, T): the input time offsets (minutes) for a batch of one.
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}

    with torch.no_grad(), torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
        prediction = model(input_batch)

    # Under bfloat16 autocast the output may be bfloat16, which has no NumPy
    # equivalent (Tensor.numpy() raises on it); return float32 instead.
    return prediction.float().cpu()


def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Render input, prediction and ground truth for one channel as PIL images.

    Args:
        last_input_map: dict channel name -> sunpy Map for the last input time,
            or None (in which case nothing is rendered).
        prediction_tensor: output of ``run_inference``; indexed as
            ``[0, c_idx]`` to get one channel's 2-D image.
        target_map: dict channel name -> sunpy Map at the target time.
        channel_name: one of SDO_CHANNELS.

    Returns:
        ``(input_image, prediction_image, target_image)`` as RGB PIL images, or
        ``(None, None, None)`` when ``last_input_map`` is None.
    """
    if last_input_map is None:
        return None, None, None

    c_idx = SDO_CHANNELS.index(channel_name)
    all_means, all_stds, all_epsilons, all_sl_scale_factors = APP_CACHE["scalers"].get_params()
    pred_slice = inverse_transform_single_channel(
        prediction_tensor[0, c_idx].numpy(),
        mean=all_means[c_idx],
        std=all_stds[c_idx],
        epsilon=all_epsilons[c_idx],
        sl_scale_factor=all_sl_scale_factors[c_idx],
    )

    def exposure_normalized(smap):
        # Same exptime normalisation the preprocessing applies before scaling,
        # so the raw maps are shown in the same units as the prediction
        # (assuming inverse_transform_single_channel undoes scaler.transform).
        exp_time = smap.meta.get('exptime', 1.0)
        if exp_time is None or exp_time <= 0:
            exp_time = 1.0
        return np.nan_to_num(smap.data / exp_time)

    input_data = exposure_normalized(last_input_map[channel_name])
    target_data = exposure_normalized(target_map[channel_name])

    # AIA channels are displayed on [0, vmax]. Other channels (HMI fields and
    # velocities) can be negative, so they get a symmetric range around zero
    # instead of having every negative value clipped away.
    signed = 'aia' not in channel_name
    if signed:
        vmax = np.quantile(np.abs(target_data), 0.995)
        vmin = -vmax
    else:
        vmax = np.quantile(target_data, 0.995)
        vmin = 0.0
    span = vmax - vmin

    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    def to_pil(data):
        clipped = np.clip(np.nan_to_num(data), vmin, vmax)
        if span > 0:
            data_norm = (clipped - vmin) / span
        else:
            # Degenerate range: render a flat image (mid-colour for signed data).
            data_norm = np.full_like(clipped, 0.5 if signed else 0.0, dtype=float)
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        # Flip vertically for display (row 0 at the bottom of the array
        # becomes the top row of the image).
        return Image.fromarray(colored).transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    return to_pil(input_data), to_pil(pred_slice), to_pil(target_data)
Gradio UI and Controllers --- def forecast_controller(dt_str): yield { log_box: gr.update(value="Starting forecast...", visible=True), run_button: gr.update(interactive=False), datetime_input: gr.update(interactive=False), results_group: gr.update(visible=False) } try: if not dt_str: raise gr.Error("Please select a date and time.") for status in setup_and_load_model(): yield { log_box: status } target_dt = datetime.datetime.fromisoformat(dt_str) data_pipeline = fetch_and_process_sdo_data(target_dt) while True: try: status = next(data_pipeline) if isinstance(status, tuple): input_tensor, last_input_map, target_map = status break yield { log_box: status } except StopIteration: raise gr.Error("Data processing pipeline finished unexpectedly.") yield { log_box: "Running AI model inference..." } prediction_tensor = run_inference(input_tensor) yield { log_box: "Generating final visualizations..." } img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171") yield { log_box: f"✅ Forecast complete for {target_dt.isoformat()}.", results_group: gr.update(visible=True), state_last_input: last_input_map, state_prediction: prediction_tensor, state_target: target_map, input_display: img_in, prediction_display: img_pred, target_display: img_target, } except Exception as e: error_str = traceback.format_exc() logger.error(f"An error occurred: {e}\n{error_str}") yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" } finally: yield { run_button: gr.update(interactive=True), datetime_input: gr.update(interactive=True) } # --- 5. Gradio UI Definition --- with gr.Blocks(theme=gr.themes.Soft()) as demo: state_last_input = gr.State() state_prediction = gr.State() state_target = gr.State() gr.Markdown( """