import functools
import os
import warnings

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from surya.model import Surya  # local package; must be importable from this script's location

# Suppress warnings for a cleaner demo experience.
# NOTE(review): this silences *all* warnings process-wide (including torch/gradio
# deprecation notices); narrow the filter if something misbehaves.
warnings.filterwarnings("ignore")

# --- 1. Define Constants and Data Channels ---
# Based on the Surya project's data preprocessing.
# NOTE(review): the order of ALL_CHANNELS is assumed to match the channel axis of
# the model's input/output tensors — TODO confirm against the Surya preprocessing.
AIA_CHANNELS = ["94", "131", "171", "193", "211", "304", "335", "1600"]
HMI_CHANNELS = ["bx", "by", "bz", "by_abs", "bz_abs"]
ALL_CHANNELS = [f"AIA {ch} Å" for ch in AIA_CHANNELS] + [f"HMI {ch}" for ch in HMI_CHANNELS]

# Hugging Face repository hosting both the checkpoint and the test sample.
_HF_REPO_ID = "nasa-ibm-ai4science/Surya-1.0"


# --- 2. Caching and Loading the Model and Data ---
# functools.lru_cache memoizes the result so the downloads, model construction and
# weight loading happen only once per process. A call that raises is not cached,
# so a failed download is retried on the next click.
@functools.lru_cache(maxsize=1)
def load_model_and_data():
    """Download the Surya checkpoint and test sample, then build the model.

    Returns:
        A ``(model, test_input, test_label)`` tuple:

        - ``model``: the ``Surya`` model with the downloaded weights, in eval
          mode, on CPU.
        - ``test_input``: the ``"input"`` entry of ``test_data.pt`` (loaded to CPU).
        - ``test_label``: the ``"label"`` entry of ``test_data.pt`` (loaded to CPU).

        Repeated calls return the same cached objects.
    """
    print("Downloading model and test data... This may take a moment.")

    # Local directories that hf_hub_download writes the files into.
    model_dir = "./surya_model"
    data_dir = "./surya_data"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    # Download the model weights and test data from Hugging Face.
    checkpoint_path = hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename="surya.366m.v1.pt",
        local_dir=model_dir,
    )
    test_data_path = hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename="test_data.pt",
        local_dir=data_dir,
    )
    print("Downloads complete.")

    # Initialize the model architecture. in_chans is derived from ALL_CHANNELS
    # (8 AIA + 5 HMI = 13) so the channel list and the model cannot drift apart.
    model = Surya(
        img_size=4096,
        patch_size=16,
        in_chans=len(ALL_CHANNELS),
        embed_dim=1280,
        spectral_blocks=2,
        attention_blocks=8,
    )

    print("Loading model weights...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source. Assumes the file is a bare state_dict (not wrapped, e.g.
    # {"model": ...}) — TODO confirm.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully.")

    # map_location="cpu": tensors saved from a GPU session must still load on
    # CPU-only hosts and must live on the same device as the CPU model.
    test_data = torch.load(test_data_path, map_location="cpu")
    test_input = test_data["input"]  # Input tensor for the model
    test_label = test_data["label"]  # Ground truth for comparison

    return model, test_input, test_label
Helper function for Image Conversion --- def tensor_to_image(tensor_slice): """ Normalizes a 2D tensor slice and converts it to a PIL Image for display. """ # Detach tensor from graph, move to CPU, and convert to numpy img_np = tensor_slice.detach().cpu().numpy() # Normalize the tensor to a 0-255 range for image display min_val, max_val = np.min(img_np), np.max(img_np) if max_val > min_val: img_np = (img_np - min_val) / (max_val - min_val) img_array = (img_np * 255).astype(np.uint8) return Image.fromarray(img_array) # --- 4. Main Prediction and Visualization Function --- def run_forecast(channel_name, progress=gr.Progress()): """ This function is triggered by the button click in the Gradio interface. It runs the model prediction and generates the images for display. """ progress(0, desc="Loading model and data (first run may be slow)...") # Load the model and data (will be fast after the first run due to caching) model, test_input, test_label = load_model_and_data() progress(0.5, desc="Running inference on the model...") # Perform the forecast with torch.no_grad(): prediction = model(test_input) progress(0.8, desc="Generating visualizations...") # Get the index of the selected channel channel_index = ALL_CHANNELS.index(channel_name) # Extract the last time step from the input sequence for display # Shape: [batch, channels, time, height, width] -> select channel, last time step input_slice = test_input[0, channel_index, -1, :, :] input_image = tensor_to_image(input_slice) # Extract the corresponding slice from the model's prediction # Shape: [batch, channels, time, height, width] -> select channel, first predicted step predicted_slice = prediction[0, channel_index, 0, :, :] predicted_image = tensor_to_image(predicted_slice) # Extract the corresponding slice from the ground truth label label_slice = test_label[0, channel_index, 0, :, :] label_image = tensor_to_image(label_slice) print(f"Forecast generated for channel: {channel_name}") return input_image, 
predicted_image, label_image # --- 5. Building the Gradio Interface --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( """