import os import tempfile from pathlib import Path # Set memory optimization environment variables os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16' import gradio as gr import datetime import numpy as np import matplotlib.pyplot as plt import cartopy.crs as ccrs import cartopy.feature as cfeature import matplotlib.tri as tri from anemoi.inference.runners.simple import SimpleRunner from ecmwf.opendata import Client as OpendataClient import earthkit.data as ekd import earthkit.regrid as ekr import matplotlib.animation as animation from functools import lru_cache import hashlib import pickle import json from typing import List, Dict, Any import logging # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # Define parameters (updating to match notebook.py) PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"] PARAM_SOIL = ["vsw", "sot"] PARAM_PL = ["gh", "t", "u", "v", "w", "q"] LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50] SOIL_LEVELS = [1, 2] DEFAULT_DATE = OpendataClient().latest() # First organize variables into categories VARIABLE_GROUPS = { "Surface Variables": { "10u": "10m U Wind Component", "10v": "10m V Wind Component", "2d": "2m Dewpoint Temperature", "2t": "2m Temperature", "msl": "Mean Sea Level Pressure", "skt": "Skin Temperature", "sp": "Surface Pressure", "tcw": "Total Column Water", "lsm": "Land-Sea Mask", "z": "Surface Geopotential", "slor": "Slope of Sub-gridscale Orography", "sdor": "Standard Deviation of Orography", }, "Soil Variables": { "stl1": "Soil Temperature Level 1", "stl2": "Soil Temperature Level 2", "swvl1": "Soil Water Volume Level 1", "swvl2": "Soil Water Volume Level 2", }, "Pressure Level Variables": {} # Will fill this dynamically } # Add pressure level variables dynamically for var in ["t", "u", "v", "w", "q", "z"]: var_name = { "t": "Temperature", "u": "U Wind Component", "v": "V Wind Component", "w": "Vertical Velocity", "q": "Specific Humidity", "z": "Geopotential" }[var] for level in LEVELS: var_id = f"{var}_{level}" VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa" # Load the model once at startup MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA # Create and set custom temp directory TEMP_DIR = Path("./gradio_temp") TEMP_DIR.mkdir(exist_ok=True) os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR) # Add these cache-related functions after the MODEL initialization def get_cache_key(date: datetime.datetime, params: List[str], levellist: List[int]) -> str: """Create a unique cache key based on the request parameters""" key_parts = [ date.isoformat(), ",".join(sorted(params)), ",".join(str(x) for x in sorted(levellist)) if levellist else "no_levels" ] key_string = "_".join(key_parts) cache_key = hashlib.md5(key_string.encode()).hexdigest() logger.info(f"Generated cache key: {cache_key} for {key_string}") return cache_key def get_cache_path(cache_key: str) -> Path: """Get the path to the cache file""" cache_dir = TEMP_DIR / "data_cache" cache_dir.mkdir(exist_ok=True) return cache_dir / f"{cache_key}.pkl" def save_to_cache(cache_key: str, data: Dict[str, Any]) -> None: """Save data to disk cache""" cache_file = get_cache_path(cache_key) try: with open(cache_file, 'wb') as f: pickle.dump(data, f) logger.info(f"Successfully saved data to cache: {cache_file}") except Exception as e: logger.error(f"Failed to save to cache: {e}") def load_from_cache(cache_key: str) -> Dict[str, Any]: """Load data from disk cache""" cache_file = get_cache_path(cache_key) if cache_file.exists(): try: with open(cache_file, 'rb') as f: data = pickle.load(f) logger.info(f"Successfully loaded data from cache: {cache_file}") return data except Exception as e: logger.error(f"Failed to load from cache: {e}") cache_file.unlink(missing_ok=True) logger.info(f"No cache file found: {cache_file}") return None # Modify the get_open_data function to use caching @lru_cache(maxsize=32) def get_cached_data(date_str: str, param_tuple: tuple, levelist_tuple: tuple) -> Dict[str, Any]: """Memory cache wrapper for get_open_data""" return get_open_data_impl( datetime.datetime.fromisoformat(date_str), list(param_tuple), list(levelist_tuple) if levelist_tuple else [] ) def get_open_data(param: List[str], levelist: List[int] = None) -> Dict[str, Any]: """Main function to get data with caching""" if levelist is None: levelist = [] # Try disk cache first (more persistent than memory cache) cache_key = get_cache_key(DEFAULT_DATE, param, levelist) logger.info(f"Checking cache for key: {cache_key}") cached_data = load_from_cache(cache_key) if cached_data is not None: logger.info(f"Cache hit for {cache_key}") return cached_data # If not in cache, download and process the data logger.info(f"Cache miss for {cache_key}, downloading fresh data") fields = get_open_data_impl(DEFAULT_DATE, param, levelist) # Save to disk cache save_to_cache(cache_key, fields) return fields def get_open_data_impl(date: datetime.datetime, param: List[str], levelist: List[int]) -> Dict[str, Any]: """Implementation of data download and processing""" fields = {} myiterable = [date - datetime.timedelta(hours=6), date] logger.info(f"Downloading data for dates: {myiterable}") for current_date in myiterable: logger.info(f"Fetching data for {current_date}") data = ekd.from_source("ecmwf-open-data", date=current_date, param=param, levelist=levelist) for f in data: assert f.to_numpy().shape == (721, 1440) values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1) values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"}) name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param") if name not in fields: fields[name] = [] fields[name].append(values) # Create a single matrix for each parameter for param, values in fields.items(): fields[param] = np.stack(values) return fields def plot_forecast(state, selected_variable): logger.info(f"Plotting forecast for {selected_variable} at time {state['date']}") # Setup the figure and axis fig = plt.figure(figsize=(15, 8)) ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0)) # Get the coordinates latitudes, longitudes = state["latitudes"], state["longitudes"] fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes) triangulation = tri.Triangulation(fixed_lons, latitudes) # Get the values values = state["fields"][selected_variable] logger.info(f"Value range: min={np.min(values):.2f}, max={np.max(values):.2f}") # Set map features ax.set_global() ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree()) ax.coastlines(resolution='50m') ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5) ax.gridlines(draw_labels=True) # Create contour plot contour = ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap='RdBu_r') # Add colorbar plt.colorbar(contour, ax=ax, orientation='horizontal', pad=0.05) # Format the date string forecast_time = state["date"] if isinstance(forecast_time, str): forecast_time = datetime.datetime.fromisoformat(forecast_time) time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC") # Get variable description var_desc = None for group in VARIABLE_GROUPS.values(): if selected_variable in group: var_desc = group[selected_variable] break var_name = var_desc if var_desc else selected_variable ax.set_title(f"{var_name} - {time_str}") # Save as PNG temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.png") plt.savefig(temp_file, bbox_inches='tight', dpi=100) plt.close() return temp_file def run_forecast(date: datetime.datetime, lead_time: int, device: str) -> Dict[str, Any]: # Get all required fields fields = {} logger.info(f"Starting forecast for lead_time: {lead_time} hours") # Get surface fields logger.info("Getting surface fields...") fields.update(get_open_data(param=PARAM_SFC)) # Get soil fields and rename them logger.info("Getting soil fields...") soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS) mapping = { 'sot_1': 'stl1', 'sot_2': 'stl2', 'vsw_1': 'swvl1', 'vsw_2': 'swvl2' } for k, v in soil.items(): fields[mapping[k]] = v # Get pressure level fields logger.info("Getting pressure level fields...") fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS)) # Convert geopotential height to geopotential for level in LEVELS: gh = fields.pop(f"gh_{level}") fields[f"z_{level}"] = gh * 9.80665 input_state = dict(date=date, fields=fields) # Use the global model instance global MODEL if device != MODEL.device: MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device) # Run the model and get the final state final_state = None for state in MODEL.run(input_state=input_state, lead_time=lead_time): logger.info(f"\nšŸ˜€ date={state['date']} latitudes={state['latitudes'].shape} " f"longitudes={state['longitudes'].shape} fields={len(state['fields'])}") # Log a few example variables to show we have all fields for var in ['2t', 'msl', 't_1000', 'z_850']: if var in state['fields']: values = state['fields'][var] logger.info(f" {var:<6} shape={values.shape} " f"min={np.min(values):.6f} " f"max={np.max(values):.6f}") final_state = state logger.info(f"Final state contains {len(final_state['fields'])} variables") return final_state def get_available_variables(state): """Get available variables from the state and organize them into groups""" available_vars = set(state['fields'].keys()) # Create dropdown choices only for available variables choices = [] for group_name, variables in VARIABLE_GROUPS.items(): group_vars = [(f"{desc} ({var_id})", var_id) for var_id, desc in variables.items() if var_id in available_vars] if group_vars: # Only add group if it has available variables choices.append((f"── {group_name} ──", None)) choices.extend(group_vars) return choices def update_interface(): with gr.Blocks(css=""" .centered-header { text-align: center; margin-bottom: 20px; } .subtitle { font-size: 1.2em; line-height: 1.5; margin: 20px 0; } .footer { text-align: center; padding: 20px; margin-top: 20px; border-top: 1px solid #eee; } """) as demo: state = gr.State(None) with gr.Row(): with gr.Column(scale=1): lead_time = gr.Slider( minimum=6, maximum=48, step=6, value=12, label="Forecast Hours Ahead" ) variable = gr.Dropdown( choices=[], # Start empty value=None, label="Select Variable to Plot" ) with gr.Row(): clear_btn = gr.Button("Clear") run_btn = gr.Button("Run Forecast", variant="primary") with gr.Row(): download_json = gr.Button("Download JSON") download_nc = gr.Button("Download NetCDF") with gr.Column(scale=2): forecast_output = gr.Image() def run_and_store(lead_time): """Run forecast and store state""" state = run_forecast(DEFAULT_DATE, lead_time, "cuda") # Get available variables choices = get_available_variables(state) # Select first real variable as default default_var = next((var_id for _, var_id in choices if var_id is not None), None) # Generate initial plot plot = plot_forecast(state, default_var) if default_var else None return [state, gr.Dropdown(choices=choices), default_var, plot] def update_plot_from_state(state, variable): """Update plot using stored state""" if state is None or variable is None: return None try: return plot_forecast(state, variable) except KeyError as e: logger.error(f"Variable {variable} not found in state: {e}") return None def clear(): """Clear everything""" return [None, None, gr.Dropdown(choices=[]), None] def save_json(state): if state is None: return None return save_forecast_data(state, 'json') def save_netcdf(state): if state is None: return None return save_forecast_data(state, 'netcdf') # Connect the components run_btn.click( fn=run_and_store, inputs=[lead_time], outputs=[state, variable, variable, forecast_output] ) variable.change( fn=update_plot_from_state, inputs=[state, variable], outputs=forecast_output ) clear_btn.click( fn=clear, inputs=[], outputs=[state, forecast_output, variable, variable] ) download_json.click( fn=save_json, inputs=[state], outputs=gr.File() ) download_nc.click( fn=save_netcdf, inputs=[state], outputs=gr.File() ) return demo # Create and launch the interface demo = update_interface() demo.launch()