surya-demo / app.py
broadfield-dev's picture
Update app.py
9013587 verified
raw
history blame
11.2 kB
# Save this file as app.py in the root of the cloned Surya repository
import gradio as gr
import torch
from huggingface_hub import snapshot_download
import yaml
import numpy as np
from PIL import Image
import sunpy.map
import sunpy.net.attrs as a
from sunpy.net import Fido
from astropy.wcs import WCS
import astropy.units as u
from reproject import reproject_interp
import os
import warnings
import logging
import datetime
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import traceback
# --- Use the official Surya modules ---
from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers
from surya.datasets.helio import inverse_transform_single_channel
# --- Configuration ---
# Silence noisy sunpy user warnings and library FutureWarnings in the demo log.
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global cache for model, config, etc.
APP_CACHE = {}

# Maps each model channel name to the (selector attr, sampling attr) pair used
# when searching for data. Insertion order matters: SDO_CHANNELS below is
# derived from it and is used to index channels during preprocessing.
SDO_CHANNELS_MAP = {
    # AIA channels: (wavelength in angstrom, sample cadence in seconds).
    **{
        f"aia{wavelength}": (a.Wavelength(wavelength * u.angstrom), a.Sample(cadence * u.s))
        for wavelength, cadence in (
            (94, 12),
            (131, 12),
            (171, 12),
            (193, 12),
            (211, 12),
            (304, 12),
            (335, 12),
            (1600, 24),
        )
    },
    # NOTE(review): "hmi_m" is queried as intensity here — confirm this is the
    # intended observable for that channel.
    "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
    "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
    "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),  # Placeholder
    "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
}
SDO_CHANNELS = list(SDO_CHANNELS_MAP)
# --- 1. Model Loading and Setup ---
def setup_and_load_model():
    """Download, build and cache the Surya model, yielding progress messages.

    This is a generator so the UI can stream status text while the (slow)
    setup runs. On success it populates ``APP_CACHE`` with the keys
    ``"config"``, ``"scalers"``, ``"device"`` and ``"model"``. If a model is
    already cached it yields a single message and returns immediately.

    Yields:
        str: Human-readable status lines.
    """
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    model_dir = "data/Surya-1.0"
    weights_file = "surya.366m.v1.pt"

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir=model_dir,
        allow_patterns=["config.yaml", "scalers.yaml", weights_file],
    )

    yield "Loading configuration and data scalers..."
    with open(f"{model_dir}/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    # Use a context manager so the scalers file handle is closed promptly.
    with open(f"{model_dir}/scalers.yaml", "r") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    # Presumably consumed by the constructor arguments elided below.
    model_config = config["model"]
    model = HelioSpectFormer(...)  # Full model definition

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device

    yield f"Loading model weights to {device}..."
    # weights_only=True restricts unpickling to tensors/containers, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    weights = torch.load(
        f"{model_dir}/{weights_file}",
        map_location=torch.device(device),
        weights_only=True,
    )
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    # Cache the model last: its presence is what marks setup as complete.
    APP_CACHE["model"] = model
    yield "✅ Model setup complete."
# --- 2. Live Data Fetching and Preprocessing (as a generator) ---
def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
    """Download SDO data around ``target_dt`` and build the model input tensor.

    Generator that streams progress strings and finishes by yielding a single
    tuple carrying the results (the caller distinguishes it via
    ``isinstance(..., tuple)``).

    Args:
        target_dt: Reference ``datetime.datetime``; the configured input
            offsets (``config["data"]["time_delta_input_minutes"]``) and the
            forecast horizon are added to it.
        forecast_horizon_minutes: Offset in minutes from ``target_dt`` to the
            target (ground-truth) observation.

    Yields:
        str: Progress messages, then finally
        tuple: ``(input_tensor, last_input_map, target_map)`` where
            ``input_tensor`` stacks channels, then input times on dim 1, with a
            leading batch dim of 1, and the two maps are ``{channel: Map}``
            dicts for the last input time and the target time.

    Raises:
        ValueError: If no data is found, or nothing could be downloaded, for a
            channel/time.
    """
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]

    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
    all_times = sorted(set(input_times + [target_time]))

    # a.Time needs an explicit start and end; search a bounded window around
    # each timestamp and ask for the record nearest to it. The half-width
    # exceeds half of the slowest cadence used above (720 s).
    search_half_window = datetime.timedelta(minutes=10)

    data_maps = {}
    total_downloads = len(all_times) * len(SDO_CHANNELS)
    downloads_done = 0
    yield f"Starting download of {total_downloads} data files..."
    for t in all_times:
        maps_at_t = data_maps[t] = {}
        for channel, (selector, sample) in SDO_CHANNELS_MAP.items():
            downloads_done += 1
            yield f"Downloading [{downloads_done}/{total_downloads}]: {channel} for {t.strftime('%Y-%m-%d %H:%M')}..."
            if channel in ("hmi_by", "hmi_bz"):
                # Placeholder channels: reuse the hmi_bx map instead of downloading.
                if "hmi_bx" in maps_at_t:
                    maps_at_t[channel] = maps_at_t["hmi_bx"]
                continue
            instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
            time_attr = a.Time(t - search_half_window, t + search_half_window, near=t)
            query = Fido.search(time_attr, instrument, selector, sample)
            if query.file_num == 0:
                raise ValueError(f"No data found for {channel} near {t}")
            # NOTE(review): relies on the provider honouring ``near`` so that
            # the result holds a single record — confirm, otherwise the whole
            # window is downloaded.
            files = Fido.fetch(query, path="./data/sdo_cache")  # Fetch the entire result
            if not files:
                raise ValueError(f"Download failed for {channel} near {t}")
            maps_at_t[channel] = sunpy.map.Map(files[0])
    yield "✅ All files downloaded. Starting preprocessing..."

    # Common output grid centred on (0, 0) at 1.2 arcsec/pixel.
    output_wcs = WCS(naxis=2)
    output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
    # Declare the unit explicitly: the WCS stores plain floats and would
    # otherwise interpret cdelt/crval as degrees.
    output_wcs.wcs.cunit = ["arcsec", "arcsec"]
    # NOTE(review): cdelt1 is negative (as in the original); confirm the
    # resulting east-west orientation matches what the model expects.
    output_wcs.wcs.cdelt = [-1.2, 1.2]
    output_wcs.wcs.crval = [0.0, 0.0]
    output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
    # NOTE(review): this WCS carries no observer/date metadata — confirm that
    # reprojection from the source maps works without it.

    scaler = APP_CACHE["scalers"]
    processed_tensors = {}
    for t, channel_maps in data_maps.items():
        channel_tensors = []
        for i, channel in enumerate(SDO_CHANNELS):
            smap = channel_maps[channel]
            reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
            # Normalise by exposure time, falling back to 1.0 when it is
            # missing or non-positive.
            exp_time = smap.meta.get('exptime', 1.0)
            if exp_time is None or exp_time <= 0:
                exp_time = 1.0
            norm_data = reprojected_data / exp_time
            scaled_data = scaler.transform(norm_data, c_idx=i)
            channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
        processed_tensors[t] = torch.stack(channel_tensors)
    yield "✅ Preprocessing complete."

    input_tensor = torch.stack([processed_tensors[t] for t in input_times], dim=1).unsqueeze(0)
    target_map = data_maps[target_time]
    last_input_map = data_maps[input_times[-1]]
    yield (input_tensor, last_input_map, target_map)
# --- 3. Inference and Visualization ---
def run_inference(input_tensor):
    """Placeholder for the model inference step (body elided in this listing).

    The controller calls this with the tensor produced by the data pipeline
    and passes the return value on to ``generate_visualization`` as the
    prediction tensor.
    """
    # This function remains the same
    ...
def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    """Placeholder for the visualization step (body elided in this listing).

    The controller unpacks the return value into three items (input image,
    predicted image, target image) for the channel named ``channel_name``.
    """
    # This function remains the same
    ...
# --- 4. Gradio UI and Controllers ---
def forecast_controller(date_str, hour, minute, forecast_horizon):
    """Run the full forecast pipeline, streaming UI updates to Gradio.

    Generator event handler: every yielded dict maps Gradio components to
    their new value/update. Controls are disabled while the job runs and
    re-enabled at the end, whether it succeeded or failed.

    Args:
        date_str: Date in ISO format (``YYYY-MM-DD``).
        hour: Hour of day from the hour slider.
        minute: Minute from the minute slider.
        forecast_horizon: Forecast horizon in minutes.

    Yields:
        dict: Component -> update mappings.
    """
    controls = (run_button, date_input, hour_slider, minute_slider, horizon_slider)

    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        # Disable every control while the forecast is running.
        **{control: gr.update(interactive=False) for control in controls},
        results_group: gr.update(visible=False),
    }
    try:
        date_str = (date_str or "").strip()
        if not date_str:
            raise gr.Error("Please select a date.")

        for status in setup_and_load_model():
            yield {log_box: status}

        # Construct datetime from the UI components.
        target_dt = datetime.datetime.fromisoformat(
            f"{date_str}T{int(hour):02d}:{int(minute):02d}:00"
        )

        # The pipeline yields status strings and, last, a tuple of results.
        pipeline_result = None
        for status in fetch_and_process_sdo_data(target_dt, forecast_horizon):
            if isinstance(status, tuple):
                pipeline_result = status
                break
            yield {log_box: status}
        if pipeline_result is None:
            raise gr.Error("Data processing pipeline finished unexpectedly.")
        input_tensor, last_input_map, target_map = pipeline_result

        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(
            last_input_map, prediction_tensor, target_map, "aia171"
        )

        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
            results_group: gr.update(visible=True),
            # Keep the raw results in state and show the rendered images.
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message
        # in the log box instead of crashing the event handler.
        logger.exception("Forecast failed")
        yield {log_box: gr.update(value=f"❌ Error: {e}", visible=True)}

    # Re-enable all controls. Done after the try/except rather than in a
    # ``finally`` because yielding from ``finally`` raises RuntimeError when
    # the generator is closed early (e.g. the client disconnects).
    yield {control: gr.update(interactive=True) for control in controls}
# --- 5. Gradio UI Definition ---
# NOTE(review): nesting below is reconstructed from the component order; the
# original indentation was lost — confirm the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # State objects remain the same
    ...
    gr.Markdown(...)  # Title remains the same

    # Default to three hours before the current UTC time. Date, hour and
    # minute all derive from one timestamp so they stay consistent and the
    # hour can never fall below the slider minimum around midnight UTC.
    default_dt = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)

    # --- NEW: Controls Section ---
    with gr.Accordion("Step 1: Configure Forecast", open=True):
        with gr.Row():
            date_input = gr.Textbox(
                label="Date",
                value=default_dt.strftime("%Y-%m-%d"),
            )
            hour_slider = gr.Slider(
                label="Hour (UTC)", minimum=0, maximum=23, step=1, value=default_dt.hour
            )
            minute_slider = gr.Slider(
                label="Minute", minimum=0, maximum=59, step=1, value=default_dt.minute
            )
        horizon_slider = gr.Slider(
            label="Forecast Horizon (minutes ahead)",
            minimum=12, maximum=120, step=12, value=12,
        )
    run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    # --- NEW: Moved log box to its own section ---
    with gr.Accordion("Step 2: View Log", open=False) as log_accordion:
        log_box = gr.Textbox(label="Log", interactive=False, visible=True, lines=5, max_lines=10)

    # --- Results section is now Step 3 ---
    with gr.Group(visible=False) as results_group:
        gr.Markdown("### Step 3: Explore Results")
        channel_selector = gr.Dropdown(...)
        with gr.Row():
            input_display = gr.Image(...)
            prediction_display = gr.Image(...)
            target_display = gr.Image(...)

    # --- Event Handlers ---
    run_button.click(
        fn=forecast_controller,
        inputs=[date_input, hour_slider, minute_slider, horizon_slider],
        outputs=[
            log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider, results_group,
            state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    )
    channel_selector.change(...)  # This remains the same
if __name__ == "__main__":
    # This is a condensed version showing only the key changes.
    # Fill in the missing `...` placeholders from previous versions to get the
    # full, runnable script before launching.
    demo.launch(debug=True)