saburq's picture
add animation
e9a1c0f
raw
history blame
11.8 kB
import os
import tempfile
from pathlib import Path
# Memory-related settings, exported before the heavy third-party imports that
# follow (presumably so PyTorch / anemoi-inference see them when they
# initialise — TODO confirm when each library reads its variable).
os.environ.update(
    {
        # PyTorch CUDA caching-allocator option (see torch.cuda memory docs).
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # anemoi-inference setting; presumably splits model work into chunks
        # to lower peak memory — verify against the anemoi-inference docs.
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)
import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation
# Fields requested from ECMWF open data to build the model's initial state.
# Single-level parameters, ending with the land-sea mask and orography fields.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw",
    "lsm", "z", "slor", "sdor",
]
# Soil parameters; renamed to stl*/swvl* ids in run_forecast.
PARAM_SOIL = ["vsw", "sot"]
# Pressure-level parameters; "gh" is converted to geopotential "z_<level>" in run_forecast.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
# Pressure levels in hPa.
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]
# Soil levels requested for PARAM_SOIL.
SOIL_LEVELS = [1, 2]

# Most recent open-data run time, resolved once at import (network call).
DEFAULT_DATE = OpendataClient().latest()
# Long names for the upper-air variables; combined with every entry of LEVELS
# to produce ids such as "t_850" -> "Temperature at 850hPa".
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}

# Variable id -> human-readable description, grouped for the dropdown and
# for plot titles.
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    # Ordered by variable first, then by level, following LEVELS.
    "Pressure Level Variables": {
        f"{short}_{hpa}": f"{long_name} at {hpa}hPa"
        for short, long_name in _PRESSURE_LEVEL_NAMES.items()
        for hpa in LEVELS
    },
}
# A single runner is created at start-up and shared by all requests;
# run_forecast rebuilds it when a different device is requested.
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")

# Local scratch directory for rendered animations, advertised to Gradio via
# GRADIO_TEMP_DIR (set after `import gradio`; presumably read lazily by
# Gradio — TODO confirm).
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = os.fspath(TEMP_DIR)
def get_open_data(param, levelist=None):
    """Download ECMWF open-data fields for DEFAULT_DATE - 6h and DEFAULT_DATE.

    Each field is read from the 721x1440 (0.25 degree) grid, rolled by half
    its width along axis 1, and interpolated to the N320 grid.

    Args:
        param: Parameter short name(s) forwarded to earthkit-data's
            ``ecmwf-open-data`` source (e.g. ``PARAM_SFC``).
        levelist: Optional level(s). When non-empty, result keys are
            ``"<param>_<level>"``; otherwise they are the bare parameter name.
            ``None`` (the default) is treated as an empty list.

    Returns:
        dict mapping field name to ``np.stack`` of the per-date arrays,
        earlier date first.

    Raises:
        ValueError: if a downloaded field is not on the expected 721x1440 grid.
    """
    # Avoid a shared mutable default; an empty list is still what gets
    # forwarded to earthkit when no levels are requested.
    if levelist is None:
        levelist = []

    expected_shape = (721, 1440)
    fields = {}
    # Fetch the time 6 h before DEFAULT_DATE first, then DEFAULT_DATE itself.
    for date in (DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE):
        print(f"Fetching data for {date}")
        # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
        for field in data:
            raw = field.to_numpy()  # decode once and reuse
            # Explicit check instead of `assert`, which is stripped under -O.
            if raw.shape != expected_shape:
                raise ValueError(
                    f"Unexpected grid shape {raw.shape} for {field.metadata('param')}; "
                    f"expected {expected_shape}"
                )
            # Shift by half the width along axis 1 (presumably moving the
            # grid's first column from -180 to 0 degrees longitude for the
            # interpolator — TODO confirm).
            values = np.roll(raw, -(raw.shape[1] // 2), axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            if levelist:
                name = f"{field.metadata('param')}_{field.metadata('levelist')}"
            else:
                name = field.metadata("param")
            fields.setdefault(name, []).append(values)

    # One array per field, stacked over the two dates in fetch order.
    return {name: np.stack(per_date) for name, per_date in fields.items()}
def _describe_variable(variable_id):
    """Return the VARIABLE_GROUPS description of *variable_id*, or the id itself if unknown."""
    for group in VARIABLE_GROUPS.values():
        if variable_id in group:
            return group[variable_id] or variable_id
    return variable_id


def _format_forecast_time(value):
    """Format a state's ``date`` for a plot title as ``YYYY-MM-DD HH:MM UTC``.

    Non-string values are formatted with ``strftime`` directly. Strings are
    parsed with ``%Y-%m-%d %H:%M:%S`` (optionally followed by ``.%f``). A
    string matching neither format is returned unchanged, rather than being
    replaced by an unrelated date. The " UTC" suffix is literal text; no
    timezone conversion happens here.
    """
    if isinstance(value, str):
        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"):
            try:
                value = datetime.datetime.strptime(value, fmt)
            except ValueError:
                continue
            break
        else:
            return value
    return value.strftime("%Y-%m-%d %H:%M UTC")


def plot_forecast_animation(states, selected_variable):
    """Render an MP4 animation of one variable across the forecast states.

    Args:
        states: Sequence of forecast state dicts. Each needs ``"fields"``
            (mapping variable id -> values) and ``"date"``. The first one also
            needs ``"latitudes"`` and ``"longitudes"``, which are used for
            every frame.
        selected_variable: Key into each state's ``"fields"``.

    Returns:
        str path of the MP4 file written inside TEMP_DIR.

    Raises:
        ValueError: if *states* is empty or *selected_variable* is not among
            the first state's fields.
    """
    if not states:
        raise ValueError("No forecast states to plot")
    first_state = states[0]
    if selected_variable not in first_state["fields"]:
        raise ValueError(f"Variable {selected_variable!r} is not in the forecast output")

    latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
    # Wrap longitudes above 180 into the -180..180 range used by the map.
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)

    # One fixed set of contour levels across all frames. An integer `levels=`
    # lets tricontourf choose levels from each frame's own data range, which
    # made colours and the colorbar change from frame to frame.
    vmin = float(min(np.nanmin(state["fields"][selected_variable]) for state in states))
    vmax = float(max(np.nanmax(state["fields"][selected_variable]) for state in states))
    if not vmax > vmin:
        # Constant field: widen the range so levels stay strictly increasing.
        vmin, vmax = vmin - 0.5, vmax + 0.5
    levels = np.linspace(vmin, vmax, 21)  # 20 filled bands

    var_name = _describe_variable(selected_variable)

    fig = plt.figure(figsize=(15, 8))
    ax = fig.add_subplot(projection=ccrs.PlateCarree(central_longitude=0))
    colorbar = None

    def update(frame):
        nonlocal colorbar
        # ax.clear() also removes the map decorations, so they are redrawn.
        ax.clear()
        ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
        ax.coastlines(resolution="50m")
        ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
        ax.gridlines(draw_labels=True)

        state = states[frame]
        contour = ax.tricontourf(
            triangulation,
            state["fields"][selected_variable],
            levels=levels,
            transform=ccrs.PlateCarree(),
            cmap="RdBu_r",
            vmin=vmin,
            vmax=vmax,
        )
        if colorbar is None:
            # Levels are shared by every frame, so one colorbar (in its own
            # axes, untouched by ax.clear()) is valid for the whole animation.
            cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03])  # [left, bottom, width, height]
            colorbar = fig.colorbar(contour, cax=cbar_ax, orientation="horizontal")

        ax.set_title(f"{var_name} - {_format_forecast_time(state['date'])}")

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=len(states),
        interval=1000,  # 1 second between frames
        repeat=True,
        blit=False,
    )

    # mkstemp gives a unique name even for concurrent requests. The empty
    # placeholder is overwritten when ffmpeg writes the video.
    fd, temp_file = tempfile.mkstemp(prefix="forecast_", suffix=".mp4", dir=os.fspath(TEMP_DIR))
    os.close(fd)
    try:
        anim.save(temp_file, writer="ffmpeg", fps=1)
    except Exception:
        # Do not leave an empty placeholder behind when rendering fails.
        Path(temp_file).unlink(missing_ok=True)
        raise
    finally:
        # Close this specific figure, even on failure, so figures don't pile up.
        plt.close(fig)
    return temp_file
def run_forecast(date, lead_time, device):
    """Assemble the AIFS initial state from ECMWF open data and run the model.

    Args:
        date: Date stored in the input state. NOTE(review): the fields are
            always fetched for DEFAULT_DATE by get_open_data, regardless of
            this argument.
        lead_time: Forecast length, passed straight to ``MODEL.run``.
        device: Device for the model. If it differs from ``MODEL.device``,
            the global MODEL is reloaded on that device.

    Returns:
        list with one snapshot dict per state yielded by ``MODEL.run``.
    """
    global MODEL

    fields = {}
    fields.update(get_open_data(param=PARAM_SFC))

    # Open-data soil field names -> the ids used elsewhere in this app.
    soil_names = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }
    for name, values in get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS).items():
        fields[soil_names[name]] = values

    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height (gh) to geopotential (z).
    standard_gravity = 9.80665  # m s-2
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * standard_gravity

    input_state = dict(date=date, fields=fields)

    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep every step for the animation. The runner may yield the same dict on
    # every step and update it in place (TODO confirm for the installed
    # anemoi-inference version), which would make all collected frames
    # identical. Snapshot the state and its "fields" mapping to guard against
    # that. Arrays are not copied; this assumes each step produces new arrays.
    return [
        {**state, "fields": dict(state["fields"])}
        for state in MODEL.run(input_state=input_state, lead_time=lead_time)
    ]
def update_plot(lead_time, variable):
    """Gradio callback: run a forecast and return the path of its animation.

    Args:
        lead_time: Forecast length in hours, from the slider.
        variable: Variable id from the dropdown. Group-header entries carry
            ``None`` as their value.

    Returns:
        Path of the MP4 produced by plot_forecast_animation.

    Raises:
        gr.Error: if no plottable variable is selected (shown in the UI).
    """
    # Reject group headers before downloading data and running the model,
    # instead of failing with a KeyError afterwards.
    if not variable:
        raise gr.Error("Please select a variable to plot (not a group header).")
    cleanup_old_files()  # Clean up old files before creating new ones
    # Slider values may arrive as floats; the slider only offers whole hours.
    states = run_forecast(DEFAULT_DATE, int(lead_time), "cuda")
    return plot_forecast_animation(states, variable)
def cleanup_old_files(max_age_seconds=3600, directory=None):
    """Delete ``*.mp4`` files whose modification time is older than a threshold.

    Args:
        max_age_seconds: Age threshold in seconds; defaults to one hour.
        directory: Directory to scan (non-recursively). Defaults to TEMP_DIR.
    """
    if directory is None:
        directory = TEMP_DIR
    now = datetime.datetime.now().timestamp()
    for file in Path(directory).glob("*.mp4"):
        try:
            if now - file.stat().st_mtime > max_age_seconds:
                file.unlink(missing_ok=True)
        except FileNotFoundError:
            # Removed by a concurrent request between glob() and stat().
            continue
# (label, value) pairs for the variable dropdown. Each group opens with a
# header entry whose value is None, followed by its variables sorted by id.
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    DROPDOWN_CHOICES += [(f"── {group_name} ──", None)]
    DROPDOWN_CHOICES.extend(
        (f"{desc} ({var_id})", var_id)
        for var_id, desc in sorted(variables.items())
    )
# Gradio UI: a header, a row with the input controls (scale 1) beside the
# video output (scale 2), and a footer, followed by the event wiring.
with gr.Blocks(css="""
.centered-header {
text-align: center;
margin-bottom: 20px;
}
.subtitle {
font-size: 1.2em;
line-height: 1.5;
margin: 20px 0;
}
.footer {
text-align: center;
padding: 20px;
margin-top: 20px;
border-top: 1px solid #eee;
}
""") as demo:
    # Header section; the start date shown is DEFAULT_DATE, resolved at import time.
    gr.Markdown(f"""
# AIFS Weather Forecast
<div class="subtitle">
Interactive visualization of ECMWF AIFS weather forecasts.<br>
Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
select how many hours ahead you want to forecast and which meteorological variable to visualize.
</div>
""")
    # Main content
    with gr.Row():
        # Input controls.
        with gr.Column(scale=1):
            # Lead time in hours: 6 to 48 in steps of 6, starting at 12.
            lead_time = gr.Slider(
                minimum=6,
                maximum=48,
                step=6,
                value=12,
                label="Forecast Hours Ahead"
            )
            # Choices are (label, value) pairs; group-header entries have value None.
            variable = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                value="2t",
                label="Select Variable to Plot"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        # Output: the MP4 path returned by update_plot.
        with gr.Column(scale=2):
            animation_output = gr.Video()
    # Footer with fork instructions and model reference
    gr.Markdown("""
<div class="footer">
<h3>Want to run this on your own?</h3>
You can fork this space and run it yourself:
1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
2. Click the "Duplicate this Space" button in the top right\n
3. Select your hardware requirements (GPU recommended)\n
4. Wait for your copy to deploy
<h3>Model Information</h3>
This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
for upper-air variables, surface weather parameters, and tropical cyclone tracks.
Note: If you encounter any issues with this demo, trying your own fork might work better!
</div>
""")

    def clear():
        """Reset the slider and dropdown to their initial values and empty the video."""
        # Order matches the `outputs` list of clear_btn.click below.
        return [
            12,
            "2t",
            None
        ]

    # Connect the inputs to the forecast function
    submit_btn.click(
        fn=update_plot,
        inputs=[lead_time, variable],
        outputs=animation_output
    )
    # Reset handler; takes no inputs.
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[lead_time, variable, animation_output]
    )

# Start the Gradio server. This runs at import time; there is no __main__ guard.
demo.launch()