import datetime
import logging
import os
import re
import traceback
import warnings
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from huggingface_hub import snapshot_download
from PIL import Image

from surya.datasets.helio import inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers

warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide cache holding the loaded config, scalers, device and model so
# that repeated forecasts do not reload them.
APP_CACHE = {}

# Maps each model input channel to the code used in SDO browse-image filenames.
# Note that hmi_bx / hmi_by / hmi_bz all map to the same "HMIB" product.
CHANNEL_TO_URL_CODE = {
    "aia94": "0094",
    "aia131": "0131",
    "aia171": "0171",
    "aia193": "0193",
    "aia211": "0211",
    "aia304": "0304",
    "aia335": "0335",
    "aia1600": "1600",
    "hmi_m": "HMIBC",
    "hmi_bx": "HMIB",
    "hmi_by": "HMIB",
    "hmi_bz": "HMIB",
    "hmi_v": "HMID",
}
SDO_CHANNELS = list(CHANNEL_TO_URL_CODE)

# Location of the model snapshot on the Hugging Face Hub and on local disk.
MODEL_REPO_ID = "nasa-ibm-ai4science/Surya-1.0"
MODEL_DIR = "data/Surya-1.0"
MODEL_CONFIG_FILE = "config.yaml"
MODEL_SCALERS_FILE = "scalers.yaml"
MODEL_WEIGHTS_FILE = "surya.366m.v1.pt"


def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress messages.

    This is a generator: each yielded value is a human-readable status string
    intended for display in the UI. On completion ``APP_CACHE`` contains the
    keys ``"config"``, ``"scalers"``, ``"device"`` and ``"model"``. If a model
    is already cached, a single "already loaded" message is yielded and
    nothing else happens.

    Yields:
        str: Status message describing the current setup step.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_DIR,
        allow_patterns=[MODEL_CONFIG_FILE, MODEL_SCALERS_FILE, MODEL_WEIGHTS_FILE],
    )

    yield "Loading configuration and data scalers..."
    with open(os.path.join(MODEL_DIR, MODEL_CONFIG_FILE), encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Use a context manager so the scalers file handle is closed promptly.
    with open(os.path.join(MODEL_DIR, MODEL_SCALERS_FILE), encoding="utf-8") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    model_config = config["model"]
    data_config = config["data"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(data_config["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={
            "type": "linear",
            "time_dim": len(data_config["time_delta_input_minutes"]),
        },
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        # NOTE(review): this reuses the "learned_flow" config key rather than a
        # dedicated latitude key; kept as in the original — confirm intent.
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    yield f"Loading model weights to {device}..."
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs and is safer for a downloaded file.
    weights = torch.load(
        os.path.join(MODEL_DIR, MODEL_WEIGHTS_FILE),
        map_location=torch.device(device),
        weights_only=True,
    )
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    # The "model" key is written last: its presence marks setup as complete.
    APP_CACHE["model"] = model

    yield "✅ Model setup complete."
def find_nearest_browse_image_url(channel, target_dt): url_code = CHANNEL_TO_URL_CODE[channel] base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse" for i in range(2): dt_to_try = target_dt - datetime.timedelta(days=i) dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/") response = requests.get(dir_url) if response.status_code != 200: continue filenames = re.findall(r'href="(\d{8}_\d{6}_4096_' + url_code + r'\.jpg)"', response.text) if not filenames: continue best_filename = "" min_diff = float('inf') for fname in filenames: try: timestamp_str = fname.split('_')[1] img_dt = datetime.datetime.strptime(f"{dt_to_try.strftime('%Y%m%d')}{timestamp_str}", "%Y%m%d%H%M%S") diff = abs((target_dt - img_dt).total_seconds()) if diff < min_diff: min_diff = diff best_filename = fname except (ValueError, IndexError): continue if best_filename: return dir_url + best_filename raise FileNotFoundError(f"Could not find any browse images for {channel} in the last 48 hours.") def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes): config = APP_CACHE["config"] img_size = config["model"]["img_size"] input_deltas = config["data"]["time_delta_input_minutes"] input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas] target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes) all_times = sorted(list(set(input_times + [target_time]))) images = {} total_fetches = len(all_times) * len(SDO_CHANNELS) fetches_done = 0 yield f"Starting search for {total_fetches} data files..." for t in all_times: images[t] = {} for channel in SDO_CHANNELS: fetches_done += 1 yield f"Finding [{fetches_done}/{total_fetches}]: Closest image for {channel} near {t.strftime('%Y-%m-%d %H:%M')}..." image_url = find_nearest_browse_image_url(channel, t) yield f"Downloading: {os.path.basename(image_url)}..." response = requests.get(image_url) response.raise_for_status() images[t][channel] = Image.open(BytesIO(response.content)) yield "✅ All images found and downloaded. 
Starting preprocessing..." scalers_dict = APP_CACHE["scalers"] processed_tensors = {} for t, channel_images in images.items(): channel_tensors = [] for i, channel in enumerate(SDO_CHANNELS): img = channel_images[channel] if img.mode != 'L': img = img.convert('L') img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS) norm_data = np.array(img_resized, dtype=np.float32) # *** FIX: Retrieve the correct scaler object from the dictionary for the current channel *** scaler = scalers_dict[channel] scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape) channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32))) processed_tensors[t] = torch.stack(channel_tensors) yield "✅ Preprocessing complete." input_tensor_list = [processed_tensors[t] for t in input_times] input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0) target_image_map = images[target_time] last_input_image_map = images[input_times[-1]] yield (input_tensor, last_input_image_map, target_image_map) def run_inference(input_tensor): model = APP_CACHE["model"] device = APP_CACHE["device"] time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"] time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device) input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor} with torch.no_grad(): with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16): prediction = model(input_batch) return prediction.cpu() def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name): if last_input_map is None: return None, None, None c_idx = SDO_CHANNELS.index(channel_name) # *** FIX: Retrieve the correct scaler object for the current channel to get its parameters *** scaler = APP_CACHE["scalers"][channel_name] params = scaler.to_dict() mean, std = params['mean'], params['std'] # Note: The inverse transform for the simplified JPEG pipeline might differ from the 
original # We will use a standard inverse scaling, which is the most logical approach here. pred_slice_scaled = prediction_tensor[0, c_idx].numpy() pred_slice = (pred_slice_scaled * std) + mean target_img_data = np.array(target_map[channel_name]) vmax = np.quantile(np.nan_to_num(target_img_data), 0.995) cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag' cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray')) def to_pil(data): data_clipped = np.nan_to_num(data) data_clipped = np.clip(data_clipped, 0, vmax) data_norm = data_clipped / vmax if vmax > 0 else data_clipped colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8) return Image.fromarray(colored) return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name] def forecast_controller(date_str, hour, minute, forecast_horizon): yield { log_box: gr.update(value="Starting forecast...", visible=True), run_button: gr.update(interactive=False), date_input: gr.update(interactive=False), hour_slider: gr.update(interactive=False), minute_slider: gr.update(interactive=False), horizon_slider: gr.update(interactive=False), results_group: gr.update(visible=False) } try: if not date_str: raise gr.Error("Please select a date.") for status in setup_and_load_model(): yield { log_box: status } target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00") data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon) while True: try: status = next(data_pipeline) if isinstance(status, tuple): input_tensor, last_input_map, target_map = status break yield { log_box: status } except StopIteration: raise gr.Error("Data processing pipeline finished unexpectedly.") yield { log_box: "Running AI model inference..." } prediction_tensor = run_inference(input_tensor) yield { log_box: "Generating final visualizations..." 
} img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171") yield { log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).", results_group: gr.update(visible=True), state_last_input: last_input_map, state_prediction: prediction_tensor, state_target: target_map, input_display: img_in, prediction_display: img_pred, target_display: img_target, } except Exception as e: error_str = traceback.format_exc() logger.error(f"An error occurred: {e}\n{error_str}") yield { log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}" } finally: yield { run_button: gr.update(interactive=True), date_input: gr.update(interactive=True), hour_slider: gr.update(interactive=True), minute_slider: gr.update(interactive=True), horizon_slider: gr.update(interactive=True), } with gr.Blocks(theme=gr.themes.Soft()) as demo: state_last_input = gr.State() state_prediction = gr.State() state_target = gr.State() gr.Markdown( """
NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.