import os

# Memory-related settings. They are assigned before the third-party imports
# below so that libraries reading them at import/initialisation time see them.
# NOTE(review): presumably PYTORCH_CUDA_ALLOC_CONF is read by PyTorch's CUDA
# allocator and ANEMOI_INFERENCE_NUM_CHUNKS by anemoi-inference — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ANEMOI_INFERENCE_NUM_CHUNKS"] = "16"

import datetime  # noqa: E402

# Third-party imports kept in their original relative order.
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import cartopy.crs as ccrs  # noqa: E402
import cartopy.feature as cfeature  # noqa: E402
import matplotlib.tri as tri  # noqa: E402
from anemoi.inference.runners.simple import SimpleRunner  # noqa: E402
from ecmwf.opendata import Client as OpendataClient  # noqa: E402
import earthkit.data as ekd  # noqa: E402
import earthkit.regrid as ekr  # noqa: E402

# Parameters retrieved from ECMWF open data (updated to match notebook.py).
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL = ["vsw", "sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]

# Latest cycle reported by the ECMWF open-data client (network call at import).
DEFAULT_DATE = OpendataClient().latest()

# Standard gravity (m s-2), used to turn geopotential height into geopotential.
STANDARD_GRAVITY = 9.80665

# Human-readable descriptions, keyed by the variable IDs shown in the UI.
VARIABLE_DESCRIPTIONS = {
    # Surface variables
    "10u": "10m U Wind Component",
    "10v": "10m V Wind Component",
    "2d": "2m Dewpoint Temperature",
    "2t": "2m Temperature",
    "msl": "Mean Sea Level Pressure",
    "skt": "Skin Temperature",
    "sp": "Surface Pressure",
    "tcw": "Total Column Water",
    "lsm": "Land-Sea Mask",
    "z": "Surface Geopotential",
    "slor": "Slope of Sub-gridscale Orography",
    "sdor": "Standard Deviation of Orography",
    # Soil variables
    "stl1": "Soil Temperature Level 1",
    "stl2": "Soil Temperature Level 2",
    "swvl1": "Soil Water Volume Level 1",
    "swvl2": "Soil Water Volume Level 2",
}

# Pressure-level descriptions are generated as "<var>_<level>" for every level.
_PL_VARIABLE_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}
VARIABLE_DESCRIPTIONS.update(
    {
        f"{var}_{level}": f"{var_name} at {level}hPa"
        for var, var_name in _PL_VARIABLE_NAMES.items()
        for level in LEVELS
    }
)


def get_open_data(param, levelist=None, date=None):
    """Download ECMWF open-data fields for two times and regrid them to N320.

    Fields are retrieved for ``date - 6h`` and ``date`` (in that order) and
    stacked along a new leading axis, one entry per retrieval time.

    Args:
        param: ECMWF parameter short names to retrieve.
        levelist: levels to retrieve (pressure or soil). Empty/None means
            single-level fields. When given, result keys are
            ``"<param>_<level>"``; otherwise they are just ``"<param>"``.
        date: analysis datetime; defaults to ``DEFAULT_DATE``.

    Returns:
        dict mapping field name to the ``np.stack`` of its regridded values.

    Raises:
        ValueError: if a retrieved field is not on the expected 721 x 1440 grid.
    """
    levelist = list(levelist) if levelist else []
    if date is None:
        date = DEFAULT_DATE

    collected = {}
    for when in (date - datetime.timedelta(hours=6), date):
        data = ekd.from_source("ecmwf-open-data", date=when, param=param, levelist=levelist)
        for f in data:
            values = f.to_numpy()
            if values.shape != (721, 1440):
                raise ValueError(
                    f"Unexpected grid shape {values.shape} for {f.metadata('param')}; "
                    "expected (721, 1440)"
                )
            # Roll each row by half its length before regridding (as in the
            # upstream notebook). NOTE(review): presumably moves the first
            # longitude from -180 to 0, the layout earthkit-regrid expects — confirm.
            values = np.roll(values, -values.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            if levelist:
                name = f"{f.metadata('param')}_{f.metadata('levelist')}"
            else:
                name = f.metadata("param")
            collected.setdefault(name, []).append(values)

    # One array per field, leading axis ordered [date - 6h, date].
    return {name: np.stack(values) for name, values in collected.items()}


def run_forecast(date, lead_time, device, checkpoint="aifs-single-mse-1.0.ckpt"):
    """Build the AIFS input state for ``date`` and run the model to ``lead_time``.

    Args:
        date: analysis datetime; data is fetched for ``date - 6h`` and ``date``.
        lead_time: forecast length handed to the runner (whole-hour floats,
            as produced by a Gradio slider, are converted to int).
        device: torch device string passed to ``SimpleRunner`` (e.g. "cuda").
        checkpoint: model checkpoint loaded by ``SimpleRunner``.

    Returns:
        The last state yielded by ``SimpleRunner.run``.

    Raises:
        RuntimeError: if the runner yields no states.
    """
    if isinstance(lead_time, float) and lead_time.is_integer():
        # NOTE(review): gr.Slider delivers floats such as 12.0; whole hours are
        # passed as int, assumed to be what anemoi's runner accepts — confirm.
        lead_time = int(lead_time)

    fields = {}

    # Surface fields.
    fields.update(get_open_data(param=PARAM_SFC, date=date))

    # Soil fields, renamed from the open-data names to the model's names.
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, date=date)
    mapping = {"sot_1": "stl1", "sot_2": "stl2", "vsw_1": "swvl1", "vsw_2": "swvl2"}
    for name, values in soil.items():
        fields[mapping[name]] = values

    # Pressure-level fields.
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, date=date))

    # The model takes geopotential (z); open data provides geopotential height (gh).
    for level in LEVELS:
        gh = fields.pop(f"gh_{level}")
        fields[f"z_{level}"] = gh * STANDARD_GRAVITY

    input_state = dict(date=date, fields=fields)

    runner = SimpleRunner(checkpoint, device=device)
    last_state = None
    for state in runner.run(input_state=input_state, lead_time=lead_time):
        # Only the final step is returned; don't keep earlier states alive.
        last_state = state
    if last_state is None:
        raise RuntimeError("The forecast runner produced no output states")
    return last_state


def plot_forecast(state, selected_variable):
    """Plot one field of a forecast state as filled contours on a global map.

    Args:
        state: mapping with "latitudes", "longitudes", "date" and a "fields"
            mapping of variable ID -> values per grid point.
        selected_variable: key into ``state["fields"]``.

    Returns:
        The matplotlib Figure (already detached from pyplot's figure manager).
    """
    latitudes = state["latitudes"]
    longitudes = np.asarray(state["longitudes"])
    values = state["fields"][selected_variable]

    # Map longitudes in (180, 360) to (-180, 0) so the triangulation lies in
    # the PlateCarree extent (same fix as the upstream AIFS notebook);
    # a no-op if longitudes are already in [-180, 180].
    longitudes = np.where(longitudes > 180, longitudes - 360, longitudes)

    fig, ax = plt.subplots(figsize=(11, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    triangulation = tri.Triangulation(longitudes, latitudes)
    contour = ax.tricontourf(
        triangulation, values, levels=20, transform=ccrs.PlateCarree(), cmap="RdBu"
    )
    # Use the figure's own axes rather than pyplot's global "current" state,
    # which is shared between concurrent requests in a server.
    ax.set_title(f"{selected_variable} at {state['date']}")
    fig.colorbar(contour, ax=ax)

    # Release pyplot's reference so figures don't pile up across requests;
    # the Figure object itself can still be rendered by the caller.
    plt.close(fig)
    return fig
# Variables offered for plotting: surface, soil, then pressure-level fields.
AVAILABLE_VARIABLES = (
    # Surface variables (same list as retrieved from open data)
    list(PARAM_SFC)
    # Soil variables (model names, renamed from open-data "sot"/"vsw")
    + ["stl1", "stl2", "swvl1", "swvl2"]
    # Pressure level variables (adding level suffix)
    + [f"{var}_{level}" for var in ["t", "u", "v", "w", "q", "z"] for level in LEVELS]
)

# Dropdown entries as (label, value) pairs; the value is the variable ID.
DROPDOWN_CHOICES = [
    (f"{VARIABLE_DESCRIPTIONS[var_id]} ({var_id})", var_id)
    for var_id in sorted(AVAILABLE_VARIABLES)
]

# Accepted input formats for the date textbox (with and without time of day).
_DATE_FORMATS = ("%Y-%m-%d %H:%M", "%Y-%m-%d")


def _parse_date(date_str):
    """Return a naive datetime parsed with one of ``_DATE_FORMATS``, or None."""
    text = (date_str or "").strip()
    for fmt in _DATE_FORMATS:
        try:
            return datetime.datetime.strptime(text, fmt)
        except ValueError:
            continue
    return None


def gradio_interface(date_str, lead_time, device, selected_variable):
    """Gradio callback: validate inputs, run the forecast and plot one variable.

    Args:
        date_str: "YYYY-MM-DD" (00:00 assumed) or "YYYY-MM-DD HH:MM".
        lead_time: forecast length in hours (gr.Slider may deliver a float).
        device: compute device passed to ``run_forecast``.
        selected_variable: variable ID to plot.

    Raises:
        gr.Error: on an unparsable date, a time that is not a 6-hourly cycle,
            or a variable that the forecast output does not contain.
    """
    date = _parse_date(date_str)
    if date is None:
        raise gr.Error("Please enter a valid date in YYYY-MM-DD or YYYY-MM-DD HH:MM format")
    # The model input uses date and date - 6h, so only 6-hourly cycles are valid.
    if date.minute or date.hour % 6:
        raise gr.Error("The forecast time must be 00:00, 06:00, 12:00 or 18:00")

    state = run_forecast(date, int(lead_time), device)

    # Some offered IDs (e.g. input-only fields) may be absent from the output.
    if selected_variable not in state["fields"]:
        raise gr.Error(f"Variable '{selected_variable}' is not part of the forecast output")
    return plot_forecast(state, selected_variable)


demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            # Include the cycle hour so the default run uses the latest cycle.
            value=DEFAULT_DATE.strftime("%Y-%m-%d %H:%M"),
            label="Forecast Date (YYYY-MM-DD [HH:MM])",
        ),
        gr.Slider(minimum=6, maximum=48, step=6, value=12, label="Lead Time (Hours)"),
        gr.Radio(choices=["cuda", "cpu"], value="cuda", label="Compute Device"),
        gr.Dropdown(
            choices=DROPDOWN_CHOICES,
            value="t_850",  # a variable ID, not a label
            label="Select Variable to Plot",
            info="Choose a meteorological variable to visualize",
        ),
    ],
    outputs=gr.Plot(),
    title="AIFS Weather Forecast",
    description="Interactive visualization of ECMWF AIFS weather forecasts. Select a date, forecast lead time, and meteorological variable to plot.",
)

if __name__ == "__main__":
    demo.launch()