import os import tempfile from pathlib import Path # Set memory optimization environment variables os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16' import gradio as gr import datetime import numpy as np import matplotlib.pyplot as plt import cartopy.crs as ccrs import cartopy.feature as cfeature import matplotlib.tri as tri from anemoi.inference.runners.simple import SimpleRunner from ecmwf.opendata import Client as OpendataClient import earthkit.data as ekd import earthkit.regrid as ekr import matplotlib.animation as animation # Define parameters (updating to match notebook.py) PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"] PARAM_SOIL = ["vsw", "sot"] PARAM_PL = ["gh", "t", "u", "v", "w", "q"] LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50] SOIL_LEVELS = [1, 2] DEFAULT_DATE = OpendataClient().latest() # First organize variables into categories VARIABLE_GROUPS = { "Surface Variables": { "10u": "10m U Wind Component", "10v": "10m V Wind Component", "2d": "2m Dewpoint Temperature", "2t": "2m Temperature", "msl": "Mean Sea Level Pressure", "skt": "Skin Temperature", "sp": "Surface Pressure", "tcw": "Total Column Water", "lsm": "Land-Sea Mask", "z": "Surface Geopotential", "slor": "Slope of Sub-gridscale Orography", "sdor": "Standard Deviation of Orography", }, "Soil Variables": { "stl1": "Soil Temperature Level 1", "stl2": "Soil Temperature Level 2", "swvl1": "Soil Water Volume Level 1", "swvl2": "Soil Water Volume Level 2", }, "Pressure Level Variables": {} # Will fill this dynamically } # Add pressure level variables dynamically for var in ["t", "u", "v", "w", "q", "z"]: var_name = { "t": "Temperature", "u": "U Wind Component", "v": "V Wind Component", "w": "Vertical Velocity", "q": "Specific Humidity", "z": "Geopotential" }[var] for level in LEVELS: var_id = f"{var}_{level}" VARIABLE_GROUPS["Pressure Level Variables"][var_id] = f"{var_name} at {level}hPa" # Load the model once at startup MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda") # Default to CUDA # Create and set custom temp directory TEMP_DIR = Path("./gradio_temp") TEMP_DIR.mkdir(exist_ok=True) os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR) def get_open_data(param, levelist=[]): fields = {} # Get the data for the current date and the previous date myiterable = [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE] print(myiterable) for date in [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]: print(f"Fetching data for {date}") # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57 data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist) for f in data: assert f.to_numpy().shape == (721, 1440) values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1) values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"}) name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param") if name not in fields: fields[name] = [] fields[name].append(values) # Create a single matrix for each parameter for param, values in fields.items(): fields[param] = np.stack(values) return fields def plot_forecast_animation(states, selected_variable): # Setup the figure and axis fig = plt.figure(figsize=(15, 8)) ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0)) # Get the first state to setup the plot first_state = states[0] latitudes, longitudes = first_state["latitudes"], first_state["longitudes"] fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes) triangulation = tri.Triangulation(fixed_lons, latitudes) # Find global min/max for consistent colorbar all_values = [state["fields"][selected_variable] for state in states] vmin, vmax = np.min(all_values), np.max(all_values) # Create a single colorbar that will be reused contour = None cbar_ax = None def update(frame): nonlocal contour, cbar_ax ax.clear() # Set map features ax.set_global() ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree()) ax.coastlines(resolution='50m') ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5) ax.gridlines(draw_labels=True) state = states[frame] values = state["fields"][selected_variable] # Clear the previous colorbar axis if it exists if cbar_ax: cbar_ax.remove() # Create new contour plot contour = ax.tricontourf(triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap='RdBu_r', vmin=vmin, vmax=vmax) # Create new colorbar cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03]) # [left, bottom, width, height] plt.colorbar(contour, cax=cbar_ax, orientation='horizontal') # Format the date string properly forecast_time = state["date"] if isinstance(forecast_time, str): try: forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S") except ValueError: try: forecast_time = datetime.datetime.strptime(forecast_time, "%Y-%m-%d %H:%M:%S.%f") except ValueError: forecast_time = DEFAULT_DATE time_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC") # Get variable description from VARIABLE_GROUPS var_desc = None for group in VARIABLE_GROUPS.values(): if selected_variable in group: var_desc = group[selected_variable] break var_name = var_desc if var_desc else selected_variable ax.set_title(f"{var_name} - {time_str}") # Create animation anim = animation.FuncAnimation( fig, update, frames=len(states), interval=1000, # 1 second between frames repeat=True, blit=False # Must be False to update the colorbar ) # Save as MP4 temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4") anim.save(temp_file, writer='ffmpeg', fps=1) plt.close() return temp_file def run_forecast(date, lead_time, device): # Get all required fields fields = {} # Get surface fields fields.update(get_open_data(param=PARAM_SFC)) # Get soil fields and rename them soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS) mapping = { 'sot_1': 'stl1', 'sot_2': 'stl2', 'vsw_1': 'swvl1', 'vsw_2': 'swvl2' } for k, v in soil.items(): fields[mapping[k]] = v # Get pressure level fields fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS)) # Convert geopotential height to geopotential for level in LEVELS: gh = fields.pop(f"gh_{level}") fields[f"z_{level}"] = gh * 9.80665 input_state = dict(date=date, fields=fields) # Use the global model instance global MODEL # If device preference changed, move model to new device if device != MODEL.device: MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device) # Collect all states instead of just the last one states = [] for state in MODEL.run(input_state=input_state, lead_time=lead_time): states.append(state) return states def update_plot(lead_time, variable): cleanup_old_files() # Clean up old files before creating new ones states = run_forecast(DEFAULT_DATE, lead_time, "cuda") return plot_forecast_animation(states, variable) # Add cleanup function for old files def cleanup_old_files(): # Remove files older than 1 hour current_time = datetime.datetime.now().timestamp() for file in TEMP_DIR.glob("*.mp4"): # Changed from *.gif to *.mp4 if current_time - file.stat().st_mtime > 3600: # 1 hour in seconds file.unlink(missing_ok=True) # Create dropdown choices with groups DROPDOWN_CHOICES = [] for group_name, variables in VARIABLE_GROUPS.items(): # Add group separator DROPDOWN_CHOICES.append((f"── {group_name} ──", None)) # Add variables in this group for var_id, desc in sorted(variables.items()): DROPDOWN_CHOICES.append((f"{desc} ({var_id})", var_id)) with gr.Blocks(css=""" .centered-header { text-align: center; margin-bottom: 20px; } .subtitle { font-size: 1.2em; line-height: 1.5; margin: 20px 0; } .footer { text-align: center; padding: 20px; margin-top: 20px; border-top: 1px solid #eee; } """) as demo: # Header section gr.Markdown(f""" # AIFS Weather Forecast