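"""Gradio demo for NASA's Surya-1.0 solar foundation model.

Downloads the model weights from the Hugging Face Hub, fetches recent SDO
browse images from sdo.gsfc.nasa.gov, runs a forecast, and visualizes the
last input, the prediction, and the ground truth side by side.
"""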
import gradio as gr
import torch
from huggingface_hub import snapshot_download
import yaml
import numpy as np
from PIL import Image
import requests
import os
import warnings
import logging
import datetime
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as sunpy_cm
import traceback
from io import BytesIO
import re

from surya.models.helio_spectformer import HelioSpectFormer
from surya.utils.data import build_scalers
from surya.datasets.helio import inverse_transform_single_channel

warnings.filterwarnings("ignore", category=UserWarning, module="sunpy")
warnings.filterwarnings("ignore", category=FutureWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
APP_CACHE = {}

# Maps each SDO channel name to the code used in browse-image filenames
# at sdo.gsfc.nasa.gov (e.g. "aia171" -> "0171").
CHANNEL_TO_URL_CODE = {
    "aia94": "0094", "aia131": "0131", "aia171": "0171", "aia193": "0193",
    "aia211": "0211", "aia304": "0304", "aia335": "0335", "aia1600": "1600",
    "hmi_m": "HMIBC", "hmi_bx": "HMIB", "hmi_by": "HMIB",
    "hmi_bz": "HMIB", "hmi_v": "HMID",
}
SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())
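# NOTE (assumption): the order of SDO_CHANNELS is taken to match the channel
# order the model was trained with (config["data"]["sdo_channels"]); the model
# is built with in_chans=len(config["data"]["sdo_channels"]) below, and indices
# from SDO_CHANNELS.index(...) are used to slice the prediction tensor.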

def setup_and_load_model():
    """Generator that downloads, builds, and caches the model, yielding status strings."""
    if "model" in APP_CACHE:
        yield "Model already loaded. Skipping setup."
        return

    yield "Downloading model files (first run only)..."
    snapshot_download(
        repo_id="nasa-ibm-ai4science/Surya-1.0",
        local_dir="data/Surya-1.0",
        allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
    )

    yield "Loading configuration and data scalers..."
    with open("data/Surya-1.0/config.yaml") as fp:
        config = yaml.safe_load(fp)
    APP_CACHE["config"] = config
    with open("data/Surya-1.0/scalers.yaml") as fp:
        scalers_info = yaml.safe_load(fp)
    APP_CACHE["scalers"] = build_scalers(info=scalers_info)

    yield "Initializing model architecture..."
    model_config = config["model"]
    model = HelioSpectFormer(
        img_size=model_config["img_size"],
        patch_size=model_config["patch_size"],
        in_chans=len(config["data"]["sdo_channels"]),
        embed_dim=model_config["embed_dim"],
        time_embedding={"type": "linear", "time_dim": len(config["data"]["time_delta_input_minutes"])},
        depth=model_config["depth"],
        n_spectral_blocks=model_config["n_spectral_blocks"],
        num_heads=model_config["num_heads"],
        mlp_ratio=model_config["mlp_ratio"],
        drop_rate=model_config["drop_rate"],
        dtype=torch.bfloat16,
        window_size=model_config["window_size"],
        dp_rank=model_config["dp_rank"],
        learned_flow=model_config["learned_flow"],
        use_latitude_in_learned_flow=model_config["learned_flow"],
        init_weights=False,
        checkpoint_layers=list(range(model_config["depth"])),
        rpe=model_config["rpe"],
        ensemble=model_config["ensemble"],
        finetune=model_config["finetune"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    APP_CACHE["device"] = device
    yield f"Loading model weights to {device}..."
    weights = torch.load("data/Surya-1.0/surya.366m.v1.pt", map_location=torch.device(device))
    model.load_state_dict(weights, strict=True)
    model.to(device)
    model.eval()
    APP_CACHE["model"] = model
    yield "✅ Model setup complete."

def find_nearest_browse_image_url(channel, target_dt):
    """Return the URL of the browse JPEG closest in time to target_dt for a channel."""
    url_code = CHANNEL_TO_URL_CODE[channel]
    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
    # Search the directory listing for the target day, then the day before.
    for i in range(2):
        dt_to_try = target_dt - datetime.timedelta(days=i)
        dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
        response = requests.get(dir_url, timeout=30)
        if response.status_code != 200:
            continue
        # Filenames look like YYYYMMDD_HHMMSS_4096_<code>.jpg.
        filenames = re.findall(r'href="(\d{8}_\d{6}_4096_' + url_code + r'\.jpg)"', response.text)
        if not filenames:
            continue
        best_filename = ""
        min_diff = float("inf")
        for fname in filenames:
            try:
                timestamp_str = fname.split("_")[1]
                img_dt = datetime.datetime.strptime(
                    f"{dt_to_try.strftime('%Y%m%d')}{timestamp_str}", "%Y%m%d%H%M%S"
                )
                diff = abs((target_dt - img_dt).total_seconds())
                if diff < min_diff:
                    min_diff = diff
                    best_filename = fname
            except (ValueError, IndexError):
                continue
        if best_filename:
            return dir_url + best_filename
    raise FileNotFoundError(
        f"Could not find any browse images for {channel} on {target_dt:%Y-%m-%d} or the day before."
    )
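
# Example (hypothetical timestamp): find_nearest_browse_image_url("aia171",
# datetime.datetime(2024, 1, 1, 12, 0)) would return something like
# ".../assets/img/browse/2024/01/01/20240101_115953_4096_0171.jpg" --
# whichever listed file is closest to noon UTC on that day.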

def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
    """Generator: yields status strings, then a final (input_tensor, last_input_map, target_map) tuple."""
    config = APP_CACHE["config"]
    img_size = config["model"]["img_size"]
    input_deltas = config["data"]["time_delta_input_minutes"]
    input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
    target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
    all_times = sorted(set(input_times + [target_time]))

    images = {}
    total_fetches = len(all_times) * len(SDO_CHANNELS)
    fetches_done = 0
    yield f"Starting search for {total_fetches} data files..."
    for t in all_times:
        images[t] = {}
        for channel in SDO_CHANNELS:
            fetches_done += 1
            yield f"Finding [{fetches_done}/{total_fetches}]: Closest image for {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
            image_url = find_nearest_browse_image_url(channel, t)
            yield f"Downloading: {os.path.basename(image_url)}..."
            response = requests.get(image_url, timeout=60)
            response.raise_for_status()
            images[t][channel] = Image.open(BytesIO(response.content))

    yield "✅ All images found and downloaded. Starting preprocessing..."
    scalers_dict = APP_CACHE["scalers"]
    processed_tensors = {}
    for t, channel_images in images.items():
        channel_tensors = []
        for channel in SDO_CHANNELS:
            img = channel_images[channel]
            if img.mode != "L":
                img = img.convert("L")  # browse JPEGs are colorized; reduce to a single intensity band
            img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
            norm_data = np.array(img_resized, dtype=np.float32)
            scaler = scalers_dict[channel]
            scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
            channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
        processed_tensors[t] = torch.stack(channel_tensors)  # (C, H, W)
    yield "✅ Preprocessing complete."

    input_tensor_list = [processed_tensors[t] for t in input_times]
    input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)  # (1, C, T, H, W)
    target_image_map = images[target_time]
    last_input_image_map = images[input_times[-1]]
    # Final yield is the data tuple; the caller distinguishes it from the str status messages.
    yield (input_tensor, last_input_image_map, target_image_map)

def run_inference(input_tensor):
    model = APP_CACHE["model"]
    device = APP_CACHE["device"]
    time_deltas = APP_CACHE["config"]["data"]["time_delta_input_minutes"]
    time_delta_tensor = torch.tensor(time_deltas, dtype=torch.float32).unsqueeze(0).to(device)
    input_batch = {"ts": input_tensor.to(device), "time_delta_input": time_delta_tensor}
    with torch.no_grad():
        with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
            prediction = model(input_batch)
    return prediction.cpu().to(torch.float32)
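
# Shape notes (inferred from this file, not from Surya's documentation): "ts" is
# the stacked input of shape (1, C, T, H, W) built in fetch_and_process_sdo_data,
# and the prediction is assumed to be (1, C, H, W), since generate_visualization
# below indexes it as prediction_tensor[0, c_idx].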

def generate_visualization(last_input_map, prediction_tensor, target_map, channel_name):
    if last_input_map is None:
        return None, None, None
    c_idx = SDO_CHANNELS.index(channel_name)
    # Look up the per-channel scaler and read its parameters as attributes.
    scaler = APP_CACHE["scalers"][channel_name]
    pred_slice = inverse_transform_single_channel(
        prediction_tensor[0, c_idx].numpy(),
        mean=scaler.mean,
        std=scaler.std,
        epsilon=scaler.epsilon,
        sl_scale_factor=scaler.sl_scale_factor,
    )
    # Clip to the 99.5th percentile of the ground truth so all three panels share a dynamic range.
    target_img_data = np.array(target_map[channel_name])
    vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
    # Use the matching SunPy colormap (e.g. "sdoaia171"), falling back to grayscale.
    cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if "aia" in channel_name else "hmimag"
    cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, "gray"))

    def to_pil(data):
        data_clipped = np.nan_to_num(data)
        data_clipped = np.clip(data_clipped, 0, vmax)
        data_norm = data_clipped / vmax if vmax > 0 else data_clipped
        colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
        return Image.fromarray(colored)

    return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]

def forecast_controller(date_str, hour, minute, forecast_horizon):
    # Disable the controls and show the log while the forecast runs.
    yield {
        log_box: gr.update(value="Starting forecast...", visible=True),
        run_button: gr.update(interactive=False),
        date_input: gr.update(interactive=False),
        hour_slider: gr.update(interactive=False),
        minute_slider: gr.update(interactive=False),
        horizon_slider: gr.update(interactive=False),
        results_group: gr.update(visible=False),
    }
    try:
        if not date_str:
            raise gr.Error("Please select a date.")
        for status in setup_and_load_model():
            yield {log_box: status}
        target_dt = datetime.datetime.fromisoformat(f"{date_str}T{int(hour):02d}:{int(minute):02d}:00")

        # Drain the data pipeline: status strings go to the log; the final tuple carries the data.
        data_pipeline = fetch_and_process_sdo_data(target_dt, forecast_horizon)
        while True:
            try:
                status = next(data_pipeline)
                if isinstance(status, tuple):
                    input_tensor, last_input_map, target_map = status
                    break
                yield {log_box: status}
            except StopIteration:
                raise gr.Error("Data processing pipeline finished unexpectedly.")

        yield {log_box: "Running AI model inference..."}
        prediction_tensor = run_inference(input_tensor)

        yield {log_box: "Generating final visualizations..."}
        img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
        yield {
            log_box: f"✅ Forecast complete for {target_dt.isoformat()} (+{forecast_horizon} mins).",
            results_group: gr.update(visible=True),
            state_last_input: last_input_map,
            state_prediction: prediction_tensor,
            state_target: target_map,
            input_display: img_in,
            prediction_display: img_pred,
            target_display: img_target,
        }
    except Exception as e:
        error_str = traceback.format_exc()
        logger.error(f"An error occurred: {e}\n{error_str}")
        yield {log_box: f"❌ ERROR: {e}\n\nTraceback:\n{error_str}"}
    finally:
        # Re-enable the controls whether the forecast succeeded or failed.
        yield {
            run_button: gr.update(interactive=True),
            date_input: gr.update(interactive=True),
            hour_slider: gr.update(interactive=True),
            minute_slider: gr.update(interactive=True),
            horizon_slider: gr.update(interactive=True),
        }

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    state_last_input = gr.State()
    state_prediction = gr.State()
    state_target = gr.State()

    gr.Markdown(
        """
# ☀️ Surya: Live Forecast Demo ☀️
### A Foundation Model for Solar Dynamics
This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
It observes the Sun in 13 channels (wavelengths of light) simultaneously, learning the relationships between phenomena such as coronal loops, magnetic fields, and solar flares. From these interconnected views it generates a holistic forecast of the entire solar disk in the near future.
<br>
<p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>
"""
    )

    # Default to roughly three hours ago (UTC), by which time browse images are typically published.
    default_dt = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=3)
    with gr.Accordion("Step 1: Configure Forecast", open=True):
        with gr.Row():
            date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=default_dt.strftime("%Y-%m-%d"))
            hour_slider = gr.Slider(label="Hour (UTC)", minimum=0, maximum=23, step=1, value=default_dt.hour)
            minute_slider = gr.Slider(label="Minute (UTC)", minimum=0, maximum=59, step=1, value=default_dt.minute)
            horizon_slider = gr.Slider(
                label="Forecast Horizon (minutes ahead)",
                minimum=12, maximum=120, step=12, value=12,
            )
        run_button = gr.Button("🔮 Generate Forecast", variant="primary")

    with gr.Accordion("Step 2: View Log", open=False):
        log_box = gr.Textbox(label="Log", interactive=False, visible=False, lines=5, max_lines=10)

    with gr.Group(visible=False) as results_group:
        gr.Markdown("### Step 3: Explore Results")
        channel_selector = gr.Dropdown(
            choices=SDO_CHANNELS, value="aia171", label="🛰️ Select SDO Channel to Visualize"
        )
        with gr.Row():
            input_display = gr.Image(label="Last Input to Model", height=512, width=512, interactive=False)
            prediction_display = gr.Image(label="Surya's Forecast", height=512, width=512, interactive=False)
            target_display = gr.Image(label="Ground Truth", height=512, width=512, interactive=False)

    run_button.click(
        fn=forecast_controller,
        inputs=[date_input, hour_slider, minute_slider, horizon_slider],
        outputs=[
            log_box, run_button, date_input, hour_slider, minute_slider, horizon_slider, results_group,
            state_last_input, state_prediction, state_target,
            input_display, prediction_display, target_display,
        ],
    )
    # Re-render the three panels from cached state when the user picks a different channel.
    channel_selector.change(
        fn=generate_visualization,
        inputs=[state_last_input, state_prediction, state_target, channel_selector],
        outputs=[input_display, prediction_display, target_display],
    )

if __name__ == "__main__":
    demo.launch(debug=True)