"""Gradio demo that runs the Surya solar foundation model on SDO browse imagery.

Model files are fetched from the Hugging Face Hub on first use; input frames
are the SDO browse JPEGs nearest to the requested times.
"""

import datetime
import logging
import os
import re
import traceback
import warnings
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import sunpy.visualization.colormaps as sunpy_cm
import torch
import yaml
from huggingface_hub import snapshot_download
from PIL import Image

from surya.datasets.helio import inverse_transform_single_channel
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers

warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
warnings.filterwarnings("ignore", category=FutureWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide cache filled by setup_and_load_model() with the keys
# "config", "scalers", "device" and "model".
APP_CACHE = {}

# Surya channel name -> product code used in SDO browse-image filenames.
# NOTE: hmi_bx, hmi_by and hmi_bz all map to the same "HMIB" browse product,
# so those three model channels are fed identical images in this demo.
CHANNEL_TO_URL_CODE = {
    "aia94": "0094",
    "aia131": "0131",
    "aia171": "0171",
    "aia193": "0193",
    "aia211": "0211",
    "aia304": "0304",
    "aia335": "0335",
    "aia1600": "1600",
    "hmi_m": "HMIBC",
    "hmi_bx": "HMIB",
    "hmi_by": "HMIB",
    "hmi_bz": "HMIB",
    "hmi_v": "HMID",
}
# Channel order used when stacking the model input tensor.
SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())

# Hugging Face repo holding the checkpoint, and the local directory it is mirrored to.
MODEL_REPO_ID = "nasa-ibm-ai4science/Surya-1.0"
MODEL_DIR = "data/Surya-1.0"
MODEL_WEIGHTS_FILE = "surya.366m.v1.pt"


def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress messages.

    Runs at most once per process: when ``APP_CACHE["model"]`` is already set,
    a single status line is yielded and the generator returns.

    Side effects:
        Populates ``APP_CACHE`` with ``"config"`` (parsed config.yaml),
        ``"scalers"`` (from ``build_scalers``), ``"device"`` (``"cuda"`` or
        ``"cpu"``) and finally ``"model"`` (the eval-mode HelioSpectFormer).

    Yields:
        str: human-readable status lines for the UI log.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_DIR,
        allow_patterns=["config.yaml", "scalers.yaml", MODEL_WEIGHTS_FILE],
    )

    yield "Loading configuration and data scalers..."
    with open(os.path.join(MODEL_DIR, "config.yaml")) as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Context manager so the handle is closed as soon as the YAML is parsed.
    with open(os.path.join(MODEL_DIR, "scalers.yaml"), "r") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    # The model's channel count comes from the config, but input tensors are
    # stacked in SDO_CHANNELS order; a different order would feed channels to
    # the wrong input slots without any error.
    config_channels = list(config["data"]["sdo_channels"])
    if config_channels != SDO_CHANNELS:
        logger.warning(
            "Config channel order %s differs from demo channel order %s; "
            "forecasts may be wrong.",
            config_channels,
            SDO_CHANNELS,
        )

    yield "Initializing model architecture..."
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        # NOTE(review): reuses the "learned_flow" flag rather than a dedicated
        # key; presumably mirrors the upstream Surya setup — confirm.
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    yield f"Loading model weights to {device}..."
    weights = torch.load(
        os.path.join(MODEL_DIR, MODEL_WEIGHTS_FILE),
        map_location=torch.device(device),
    )
    # strict=True: any key mismatch between checkpoint and architecture fails loudly.
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    # Stored last: the presence of "model" is the "setup finished" flag checked above,
    # so a failure part-way through leaves setup to be retried on the next call.
    APP_CACHE["model"] = model

    yield "✅ Model setup complete."
# (connect, read) timeouts in seconds for every request to the SDO web server,
# so a stalled connection fails the forecast instead of hanging it forever.
SDO_REQUEST_TIMEOUT = (10, 60)


def _list_browse_directory(dir_url, listing_cache=None):
    """Return the HTML index of an SDO browse directory, or None if it is unavailable.

    When *listing_cache* is a dict, results (including misses) are memoised in
    it by URL, so a single forecast does not re-download the same day index
    for every channel and timestamp.
    """
    if listing_cache is not None and dir_url in listing_cache:
        return listing_cache[dir_url]
    response = requests.get(dir_url, timeout=SDO_REQUEST_TIMEOUT)
    listing = response.text if response.status_code == 200 else None
    if listing_cache is not None:
        listing_cache[dir_url] = listing
    return listing


def find_nearest_browse_image_url(channel, target_dt, listing_cache=None):
    """Return the URL of the 4096px SDO browse JPEG closest in time to *target_dt*.

    Both the daily browse directory of *target_dt* and that of the previous day
    are searched, and the single closest image across them is returned. There
    is no maximum time tolerance: the closest listed image wins however far
    away it is.

    Args:
        channel: key of CHANNEL_TO_URL_CODE (e.g. "aia171").
        target_dt: naive datetime to match against the filename timestamps.
        listing_cache: optional dict used to memoise directory listings
            across calls (see _list_browse_directory).

    Raises:
        KeyError: if *channel* is not in CHANNEL_TO_URL_CODE.
        FileNotFoundError: if neither directory lists an image for *channel*.
    """
    url_code = CHANNEL_TO_URL_CODE[channel]
    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
    # The trailing "\.jpg" keeps "HMIB" from also matching "HMIBC" files.
    pattern = re.compile(r'href="(\d{8}_\d{6}_4096_' + re.escape(url_code) + r'\.jpg)"')

    best_url = None
    min_diff = float('inf')
    for days_back in range(2):
        day = target_dt - datetime.timedelta(days=days_back)
        dir_url = day.strftime(f"{base_url}/%Y/%m/%d/")
        listing = _list_browse_directory(dir_url, listing_cache)
        if listing is None:
            continue
        for fname in pattern.findall(listing):
            try:
                # Filenames start with "YYYYMMDD_HHMMSS".
                img_dt = datetime.datetime.strptime(fname[:15], "%Y%m%d_%H%M%S")
            except ValueError:
                continue
            diff = abs((target_dt - img_dt).total_seconds())
            if diff < min_diff:
                min_diff = diff
                best_url = dir_url + fname

    if best_url is None:
        raise FileNotFoundError(f"Could not find any browse images for {channel} in the last 48 hours.")
    return best_url


def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
    """Download and normalise the SDO frames needed for one forecast.

    Input times are ``target_dt`` shifted by each entry of
    ``config["data"]["time_delta_input_minutes"]``; the ground-truth time is
    ``target_dt + forecast_horizon_minutes``. Requires setup_and_load_model()
    to have populated APP_CACHE.

    Yields:
        str status lines while working, then as the final item a tuple
        ``(input_tensor, last_input_image_map, target_image_map)`` where
        ``input_tensor`` has shape (1, channels, n_input_times, img_size, img_size)
        and both maps are dicts channel -> PIL image.
    """
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]

    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
    all_times = sorted(set(input_times + [target_time]))

    # Per-call caches: day listings are shared by every channel/time, and the
    # three HMI vector channels resolve to the same browse image URL.
    listing_cache = {}
    downloaded = {}
    images = {}
    total_fetches = len(all_times) * len(SDO_CHANNELS)
    fetches_done = 0
    yield f"Starting search for {total_fetches} data files..."
    for t in all_times:
        images[t] = {}
        for channel in SDO_CHANNELS:
            fetches_done += 1
            yield f"Finding [{fetches_done}/{total_fetches}]: Closest image for {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
            image_url = find_nearest_browse_image_url(channel, t, listing_cache)
            if image_url not in downloaded:
                yield f"Downloading: {os.path.basename(image_url)}..."
                response = requests.get(image_url, timeout=SDO_REQUEST_TIMEOUT)
                response.raise_for_status()
                downloaded[image_url] = response.content
            images[t][channel] = Image.open(BytesIO(downloaded[image_url]))

    yield "✅ All images found and downloaded. Starting preprocessing..."
    scalers_dict = APP_CACHE["scalers"]
    processed_tensors = {}
    # Only model inputs are normalised; the ground-truth frame is displayed raw.
    for t in dict.fromkeys(input_times):
        channel_tensors = []
        for channel in SDO_CHANNELS:
            img = images[t][channel]
            if img.mode != 'L':
                img = img.convert('L')
            img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
            raw = np.array(img_resized, dtype=np.float32)
            scaled = scalers_dict[channel].transform(raw.reshape(-1, 1)).reshape(raw.shape)
            channel_tensors.append(torch.from_numpy(scaled.astype(np.float32)))
        # (channels, H, W) for this timestamp.
        processed_tensors[t] = torch.stack(channel_tensors)
    yield "✅ Preprocessing complete."

    # Stack time on dim 1 and add a batch dim: (1, channels, time, H, W).
    input_tensor = torch.stack([processed_tensors[t] for t in input_times], dim=1).unsqueeze(0)
    yield (input_tensor, images[input_times[-1]], images[target_time])


def run_inference(input_tensor):
    """Run the cached Surya model on *input_tensor* and return a float32 CPU tensor.

    The time-delta conditioning comes from config["data"]["time_delta_input_minutes"];
    the forward pass runs under torch.no_grad() with bfloat16 autocast.
    """
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
    with torch.no_grad():
        with torch.autocast(device_type=device.split(':')[0], dtype=torch.bfloat16):
            prediction = model(input_batch)
    return prediction.cpu().to(torch.float32)


def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Render (last input, prediction, ground truth) images for one channel.

    Args:
        last_input_map: dict channel -> PIL image of the last model input, or None.
        prediction_tensor: output of run_inference, indexed as [0, channel_index]
            (presumably (batch, channels, H, W) — confirm), or None.
        target_map: dict channel -> PIL image of the observed frame, or None.
        channel_name: one of SDO_CHANNELS.

    Returns:
        A tuple of three PIL images, or (None, None, None) when there is no
        forecast yet or no valid channel is selected.
    """
    if (
        last_input_map is None
        or prediction_tensor is None
        or target_map is None
        or channel_name not in SDO_CHANNELS
    ):
        return None, None, None

    c_idx = SDO_CHANNELS.index(channel_name)
    scaler = APP_CACHE["scalers"][channel_name]
    # Map the prediction back out of the scaler's normalised space.
    pred_slice = inverse_transform_single_channel(
        prediction_tensor[0, c_idx].numpy(),
        mean=scaler.mean,
        std=scaler.std,
        epsilon=scaler.epsilon,
        sl_scale_factor=scaler.sl_scale_factor,
    )

    # The colour scale is anchored to the observed frame so the prediction is
    # shown on the same brightness range as the ground truth.
    target_img_data = np.array(target_map[channel_name])
    vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)

    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))

    def to_pil(data):
        """Clip to [0, vmax], normalise, colour-map and return an RGB PIL image."""
        data_clipped = np.clip(np.nan_to_num(data), 0, vmax)
        data_norm = data_clipped / vmax if vmax > 0 else data_clipped
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        return Image.fromarray(colored)

    return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]


def forecast_controller(date_str, hour, minute, forecast_horizon, channel_name="aia171"):
    """Run one end-to-end forecast, streaming Gradio update dicts.

    Disables the inputs, loads the model if needed, fetches and preprocesses
    SDO data, runs inference and renders *channel_name* (the channel currently
    selected in the dropdown). Validation problems (gr.Error) are shown in the
    log without a traceback; unexpected errors are shown with one. Inputs are
    re-enabled at the end of every run that is not cancelled.

    NOTE(review): ``forecast_horizon`` only selects which observed frame is
    shown as ground truth; run_inference() never receives it, so the model's
    own lead time does not change with the slider. Confirm this is intended.
    """
    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        run_button: gr.update(interactive=False),
        date_input: gr.update(interactive=False),
        hour_slider: gr.update(interactive=False),
        minute_slider: gr.update(interactive=False),
        horizon_slider: gr.update(interactive=False),
        results_group: gr.update(visible=False),
    }

    try:
        date_str = (date_str or "").strip()
        if not date_str:
            raise gr.Error("Please select a date.")
        try:
            target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")
        except ValueError:
            raise gr.Error(f"Invalid date {date_str!r}; expected YYYY-MM-DD.") from None
        # A ground-truth frame cannot exist for a time that has not happened yet.
        now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        if target_dt + datetime.timedelta(minutes=forecast_horizon) > now_utc:
            raise gr.Error(
                "The forecast target time is in the future, so no ground-truth image exists yet. "
                "Choose an earlier start time or a shorter horizon."
            )

        for status in setup_and_load_model():
            yield {log_box: status}

        result = None
        for status in fetch_and_process_sdo_data(target_dt, forecast_horizon):
            if isinstance(status, tuple):
                result = status
                break
            yield {log_box: status}
        if result is None:
            raise gr.Error("Data processing pipeline finished unexpectedly.")
        input_tensor, last_input_map, target_map = result

        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(
            last_input_map, prediction_tensor, target_map, channel_name
        )
        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
            results_group: gr.update(visible=True),
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }
    except gr.Error as e:
        logger.warning("Forecast rejected: %s", e)
        yield {log_box: f"❌ ERROR: {e}"}
    except Exception as e:
        error_str = traceback.format_exc()
        logger.error(f"An error occurred: {e}\n{error_str}")
        yield {log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}"}

    # Deliberately not in a `finally`: yielding while the generator is being
    # closed (e.g. the event is cancelled) raises RuntimeError.
    yield {
        run_button: gr.update(interactive=True),
        date_input: gr.update(interactive=True),
        hour_slider: gr.update(interactive=True),
        minute_slider: gr.update(interactive=True),
        horizon_slider: gr.update(interactive=True),
    }


def _default_start_fields():
    """Return (date "YYYY-MM-DD", hour, minute) for three hours ago in UTC.

    All three values come from one clock reading so they always agree.
    """
    start = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)
    return start.strftime("%Y-%m-%d"), start.hour, start.minute


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

    gr.Markdown(
        """
# ☀️ Surya: Live Forecast Demo ☀️
### A Foundation Model for Solar Dynamics

This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun. It looks at the Sun in 13 different channels (eight AIA ultraviolet wavelengths plus five HMI magnetic-field and Doppler-velocity maps) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.

NOTE: This demo uses lower-quality browse images (JPEG previews) for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.
"""
    )

    initial_date, initial_hour, initial_minute = _default_start_fields()
    with gr.Accordion("Step 1: Configure Forecast", open=True):
        with gr.Row():
            date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=initial_date)
            hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=initial_hour)
            minute_slider = gr.Slider(label="Minute (UTC)", minimum=0, maximum=59, step=1, value=initial_minute)
        horizon_slider = gr.Slider(
            label="Forecast Horizon (minutes ahead)", minimum=12, maximum=120, step=12, value=12
        )
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    with gr.Accordion("Step 2: View Log", open=False):
        log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)

    with gr.Group(visible=False) as results_group:
        gr.Markdown("### Step 3: Explore Results")
        channel_selector = gr.Dropdown(
            choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
        )
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    # Refresh the date/time defaults on every page load; values computed once
    # at startup would go stale on a long-running server.
    demo.load(fn=_default_start_fields, outputs=[date_input, hour_slider, minute_slider])

    run_button.click(
        fn=forecast_controller,
        inputs=[date_input, hour_slider, minute_slider, horizon_slider, channel_selector],
        outputs=[
            log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider,
            results_group, state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    )

    channel_selector.change(
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )

if __name__ == "__main__":
    demo.launch(debug=True)