Spaces:
Build error
Build error
import os | |
import tempfile | |
from pathlib import Path | |
# Memory-related settings. These are exported before torch / anemoi are
# imported further down, because such libraries may read them at import or
# initialisation time.
#   PYTORCH_CUDA_ALLOC_CONF: configures PyTorch's CUDA caching allocator;
#     expandable segments help reduce fragmentation.
#   ANEMOI_INFERENCE_NUM_CHUNKS: presumably splits anemoi-inference work into
#     16 chunks to lower peak memory — TODO confirm against anemoi docs.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)
import gradio as gr | |
import datetime | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cartopy.crs as ccrs | |
import cartopy.feature as cfeature | |
import matplotlib.tri as tri | |
from anemoi.inference.runners.simple import SimpleRunner | |
from ecmwf.opendata import Client as OpendataClient | |
import earthkit.data as ekd | |
import earthkit.regrid as ekr | |
import matplotlib.animation as animation | |
# Input parameters requested from ECMWF open data (kept in sync with notebook.py).

# Single-level (surface) parameters.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]

# Soil parameters; fetched at SOIL_LEVELS and renamed in run_forecast.
PARAM_SOIL = ["vsw", "sot"]

# Pressure-level parameters; fetched at every entry of LEVELS (hPa).
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]

LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

SOIL_LEVELS = [1, 2]

# Latest run reported by the open-data client; evaluated once, at import time
# (presumably a request to the open-data server).
DEFAULT_DATE = OpendataClient().latest()
# Human-readable descriptions of every plottable variable, grouped for the
# variable dropdown. The keys are used directly to index a forecast state's
# "fields" dict when plotting.

# Long names for the pressure-level variables; insertion order defines the
# order of the generated entries.
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}

VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    # One entry per (variable, level) pair, e.g. "t_850" -> "Temperature at 850hPa".
    # Built with a comprehension so no loop variables leak into module scope.
    "Pressure Level Variables": {
        f"{var}_{level}": f"{var_name} at {level}hPa"
        for var, var_name in _PRESSURE_LEVEL_NAMES.items()
        for level in LEVELS
    },
}
# Load the AIFS checkpoint a single time at startup so every request reuses
# it; run_forecast only reloads it when a different device is requested.
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")

# Scratch directory for rendered animations. It is also exported as
# GRADIO_TEMP_DIR, presumably so Gradio places its own temporary files here.
TEMP_DIR = Path("./gradio_temp")
os.makedirs(TEMP_DIR, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = os.fspath(TEMP_DIR)
def get_open_data(param, levelist=None, date=None):
    """Fetch ECMWF open-data fields for two consecutive times and regrid them to N320.

    For each of ``date - 6h`` and ``date``, every field matching *param* and
    *levelist* is downloaded with earthkit-data. Its longitude axis is rolled
    by half the grid width, and it is interpolated from the regular 0.25 deg
    grid to the N320 grid with earthkit-regrid.

    Args:
        param: ECMWF parameter short names to request.
        levelist: optional levels. When non-empty, result keys are
            ``"<param>_<level>"``; otherwise they are just ``"<param>"``.
        date: reference time; defaults to DEFAULT_DATE.

    Returns:
        dict mapping field name -> ``np.ndarray``, with the two dates stacked
        along a new leading axis (earlier date first).

    Raises:
        ValueError: if a downloaded field is not on the expected 721x1440 grid.
    """
    if levelist is None:
        levelist = []
    if date is None:
        date = DEFAULT_DATE

    dates = [date - datetime.timedelta(hours=6), date]
    print(dates)

    fields = {}
    for when in dates:
        print(f"Fetching data for {when}")
        # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=when, param=param, levelist=levelist)
        for f in data:
            raw = f.to_numpy()  # decode the field once and reuse it
            # Explicit check rather than `assert`, which is stripped under -O.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Expected a 721x1440 field for {f.metadata('param')}, got {raw.shape}"
                )
            # Rotate longitudes by half a revolution before regridding
            # (presumably to match the longitude origin ekr expects — confirm).
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)

    # One array per field, time steps stacked in the order of `dates`.
    return {name: np.stack(values) for name, values in fields.items()}
def _format_forecast_time(value):
    """Format a state's ``date`` entry as ``"YYYY-MM-DD HH:MM UTC"``.

    Strings in ``"%Y-%m-%d %H:%M:%S"`` or ``"%Y-%m-%d %H:%M:%S.%f"`` form are
    parsed first. Unparseable strings fall back to DEFAULT_DATE. Anything else
    must provide ``strftime``. The " UTC" suffix is literal text and is not
    derived from the value's timezone.
    """
    if isinstance(value, str):
        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"):
            try:
                value = datetime.datetime.strptime(value, fmt)
                break
            except ValueError:
                continue
        else:
            value = DEFAULT_DATE
    return value.strftime("%Y-%m-%d %H:%M UTC")


def plot_forecast_animation(states, selected_variable):
    """Render *selected_variable* over all forecast *states* as an MP4 and return its path.

    Args:
        states: forecast states (as collected by run_forecast). Each one is
            indexed with "date" and ``["fields"][selected_variable]``. The grid
            ("latitudes"/"longitudes") is taken from the first state only.
        selected_variable: field key to plot, e.g. "2t" or "t_850".

    Returns:
        Path (str) of the MP4 written into TEMP_DIR: one frame per state at 1 fps.

    Raises:
        IndexError: if *states* is empty.
        KeyError: if *selected_variable* is missing from a state's fields.
    """
    fig = plt.figure(figsize=(15, 8))
    try:
        ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

        # The grid is taken to be the same for every state, so triangulate once.
        first_state = states[0]
        latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
        fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
        triangulation = tri.Triangulation(fixed_lons, latitudes)

        # Global range over all frames. tricontourf places integer `levels`
        # from each frame's own min/max; vmin/vmax only affect the colour norm.
        # Explicit shared level boundaries keep the contours and the colorbar
        # identical from frame to frame.
        frame_values = [state["fields"][selected_variable] for state in states]
        vmin = min(np.min(v) for v in frame_values)
        vmax = max(np.max(v) for v in frame_values)
        # A constant field cannot form increasing levels; let matplotlib choose.
        levels = np.linspace(vmin, vmax, 21) if vmax > vmin else 20

        # Title label is the same for every frame: look it up once.
        var_desc = next(
            (group[selected_variable] for group in VARIABLE_GROUPS.values() if selected_variable in group),
            None,
        )
        var_name = var_desc or selected_variable

        contour = None
        cbar_ax = None

        def update(frame):
            nonlocal contour, cbar_ax
            ax.clear()
            # Map features (cleared together with the previous frame).
            ax.set_global()
            ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
            ax.coastlines(resolution='50m')
            ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
            ax.gridlines(draw_labels=True)

            state = states[frame]
            values = state["fields"][selected_variable]

            # Drop the previous frame's colorbar before drawing a new one.
            if cbar_ax is not None:
                cbar_ax.remove()
            contour = ax.tricontourf(triangulation, values,
                                     levels=levels, transform=ccrs.PlateCarree(),
                                     cmap='RdBu_r', vmin=vmin, vmax=vmax)
            cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03])  # [left, bottom, width, height]
            fig.colorbar(contour, cax=cbar_ax, orientation='horizontal')

            ax.set_title(f"{var_name} - {_format_forecast_time(state['date'])}")

        anim = animation.FuncAnimation(
            fig, update,
            frames=len(states),
            interval=1000,  # 1 second between frames
            repeat=True,
            blit=False  # Must be False to update the colorbar
        )

        temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4")
        anim.save(temp_file, writer='ffmpeg', fps=1)
    finally:
        # Always release the figure, even when rendering/encoding fails.
        plt.close(fig)
    return temp_file
def run_forecast(date, lead_time, device):
    """Fetch initial conditions, run the AIFS model and return every forecast state.

    Args:
        date: start time placed in the model's input state.
            NOTE(review): the input fields are always fetched for DEFAULT_DATE
            (and 6 h earlier) by get_open_data. Passing any other *date* would
            label DEFAULT_DATE data with a different time — confirm callers
            only pass DEFAULT_DATE.
        lead_time: forecast horizon forwarded to ``MODEL.run``.
        device: runner device, e.g. "cuda". If it differs from ``MODEL.device``,
            the global MODEL is reloaded on that device.

    Returns:
        list of every state yielded by ``MODEL.run``, in order.
    """
    global MODEL

    # Surface fields, keyed by ECMWF short name.
    fields = {}
    fields.update(get_open_data(param=PARAM_SFC))

    # Soil fields arrive as "<param>_<level>"; rename them to the short names
    # listed under "Soil Variables" in VARIABLE_GROUPS.
    soil_renames = {
        "sot_1": "stl1", "sot_2": "stl2",
        "vsw_1": "swvl1", "vsw_2": "swvl2",
    }
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    fields.update({soil_renames[name]: values for name, values in soil.items()})

    # Pressure-level fields, keyed "<param>_<level>".
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Convert geopotential height to geopotential: z = g0 * gh,
    # with standard gravity g0 = 9.80665 m s^-2.
    standard_gravity = 9.80665
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * standard_gravity

    input_state = dict(date=date, fields=fields)

    # Reload the checkpoint only when a different device is requested.
    # NOTE(review): assumes MODEL.device compares equal to the plain device string.
    if device != MODEL.device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    # Keep every intermediate state (not just the last) so it can be animated.
    return list(MODEL.run(input_state=input_state, lead_time=lead_time))
def update_plot(lead_time, variable):
    """Gradio submit handler: run a forecast and return the path of the rendered MP4.

    Args:
        lead_time: forecast horizon in hours (from the slider).
        variable: selected dropdown value. Group-header entries carry None.

    Raises:
        gr.Error: if no plottable variable is selected. This is checked before
            the expensive data download and model run, which would otherwise
            only fail afterwards, at plotting time.
    """
    if not variable:
        raise gr.Error("Please select a variable to plot; group headers cannot be plotted.")
    cleanup_old_files()  # Clean up old files before creating new ones
    states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
    return plot_forecast_animation(states, variable)
# Housekeeping for rendered animations.
def cleanup_old_files(max_age_seconds=3600, directory=None):
    """Delete ``*.mp4`` files older than *max_age_seconds* (default 1 hour).

    Args:
        max_age_seconds: files whose modification time is more than this many
            seconds in the past are removed.
        directory: folder to scan; defaults to TEMP_DIR. A missing folder is
            simply treated as empty.
    """
    if directory is None:
        directory = TEMP_DIR
    now = datetime.datetime.now().timestamp()
    for path in Path(directory).glob("*.mp4"):
        try:
            age = now - path.stat().st_mtime
        except FileNotFoundError:
            # Removed by someone else between glob() and stat(); nothing to do.
            continue
        if age > max_age_seconds:
            path.unlink(missing_ok=True)
# Dropdown choices as (label, value) pairs, in VARIABLE_GROUPS order. Each
# group starts with a header entry whose value is None (not a plottable
# variable), followed by that group's variables sorted by key.
DROPDOWN_CHOICES = []
for group_name, variables in VARIABLE_GROUPS.items():
    # Group header; the label uses box-drawing dashes (previously mis-encoded as "ββ").
    DROPDOWN_CHOICES.append((f"── {group_name} ──", None))
    DROPDOWN_CHOICES.extend(
        (f"{desc} ({var_id})", var_id) for var_id, desc in sorted(variables.items())
    )
# Gradio UI. Components appear in the order they are created inside the
# Blocks/Row/Column context managers, so statement order here defines the layout.
with gr.Blocks(css="""
.centered-header {
text-align: center;
margin-bottom: 20px;
}
.subtitle {
font-size: 1.2em;
line-height: 1.5;
margin: 20px 0;
}
.footer {
text-align: center;
padding: 20px;
margin-top: 20px;
border-top: 1px solid #eee;
}
""") as demo:
    # Header section: title plus a subtitle showing the forecast start time (DEFAULT_DATE).
    gr.Markdown(f"""
# AIFS Weather Forecast
<div class="subtitle">
Interactive visualization of ECMWF AIFS weather forecasts.<br>
Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
select how many hours ahead you want to forecast and which meteorological variable to visualize.
</div>
""")
    # Main content: controls on the left (scale=1), animation on the right (scale=2).
    with gr.Row():
        with gr.Column(scale=1):
            # Forecast horizon in hours, 6 to 48 in 6-hour steps.
            lead_time = gr.Slider(
                minimum=6,
                maximum=48,
                step=6,
                value=12,
                label="Forecast Hours Ahead"
            )
            # Choices are (label, value) pairs; group-header entries have value None.
            variable = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                value="2t",
                label="Select Variable to Plot"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            # Shows the MP4 file path returned by update_plot.
            animation_output = gr.Video()
    # Footer with fork instructions and model reference
    gr.Markdown("""
<div class="footer">
<h3>Want to run this on your own?</h3>
You can fork this space and run it yourself:
1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
2. Click the "Duplicate this Space" button in the top right\n
3. Select your hardware requirements (GPU recommended)\n
4. Wait for your copy to deploy
<h3>Model Information</h3>
This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF,
which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts
for upper-air variables, surface weather parameters, and tropical cyclone tracks.
Note: If you encounter any issues with this demo, trying your own fork might work better!
</div>
""")

    def clear():
        # Reset values, in the same order as the `outputs` list of clear_btn.click:
        # slider back to its initial 12 h, dropdown back to "2t", video emptied.
        # Keep these in sync with the component defaults above.
        return [
            12,
            "2t",
            None
        ]

    # Connect the inputs to the forecast function
    submit_btn.click(
        fn=update_plot,
        inputs=[lead_time, variable],
        outputs=animation_output
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[lead_time, variable, animation_output]
    )
# Start the Gradio server (module-level, so it runs whenever this file is executed).
demo.launch()