File size: 11,769 Bytes
a952d46
e9a1c0f
 
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9a1c0f
a952d46
 
 
 
 
 
 
 
 
62a6171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f7649
 
62a6171
93f7649
 
 
 
 
 
 
 
 
 
 
62a6171
 
93f7649
45b15ae
 
 
e9a1c0f
 
 
 
 
a952d46
 
 
45b15ae
 
a952d46
45b15ae
 
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9a1c0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a952d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b15ae
 
 
 
 
 
 
e9a1c0f
 
45b15ae
e9a1c0f
 
a952d46
e9a1c0f
 
 
 
 
 
 
 
 
 
 
 
a952d46
62a6171
 
 
 
 
 
 
 
93f7649
9a29800
 
 
 
 
e9a1c0f
 
 
 
 
 
 
 
 
 
 
9a29800
e9a1c0f
9a29800
 
 
e9a1c0f
 
 
45b15ae
e9a1c0f
 
9a29800
e9a1c0f
9a29800
 
 
 
 
 
 
 
 
 
 
 
e9a1c0f
9a29800
 
 
 
 
 
e9a1c0f
9a29800
e9a1c0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a29800
 
 
e9a1c0f
 
 
9a29800
 
e9a1c0f
 
 
 
 
 
 
 
 
 
 
a952d46
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import os
import tempfile
from pathlib import Path
# Set memory optimization environment variables.
# These are assigned before the third-party imports below, presumably so the
# libraries (PyTorch via anemoi.inference) see them when they initialise — TODO confirm.
# PyTorch CUDA caching allocator option (see PyTorch CUDA memory-management docs).
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# NOTE(review): presumably tells anemoi-inference to split work into 16 chunks to
# lower peak GPU memory — confirm against the anemoi-inference documentation.
os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'

import gradio as gr
import datetime
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.tri as tri
from anemoi.inference.runners.simple import SimpleRunner
from ecmwf.opendata import Client as OpendataClient
import earthkit.data as ekd
import earthkit.regrid as ekr
import matplotlib.animation as animation

# Define parameters (kept in sync with notebook.py).
# Surface / single-level parameters requested from ECMWF open data.
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
# Soil parameters; run_forecast renames sot_N/vsw_N to stlN/swvlN.
PARAM_SOIL = ["vsw", "sot"]
# Pressure-level parameters; run_forecast converts "gh" to geopotential "z".
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
# Pressure levels (hPa) requested for PARAM_PL.
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
# Levels requested for PARAM_SOIL.
SOIL_LEVELS = [1, 2]
# Most recent open-data run time. NOTE(review): this presumably queries the ECMWF
# open-data server, i.e. a network call at import time — confirm.
DEFAULT_DATE = OpendataClient().latest()

# Long names for the pressure-level variables; every one is expanded over LEVELS below.
_PRESSURE_LEVEL_NAMES = {
    "t": "Temperature",
    "u": "U Wind Component",
    "v": "V Wind Component",
    "w": "Vertical Velocity",
    "q": "Specific Humidity",
    "z": "Geopotential",
}

# Human-readable descriptions of every plottable variable, grouped by category.
# Keys match the field names in the model's output states; descriptions are
# used for the dropdown labels and plot titles.
VARIABLE_GROUPS = {
    "Surface Variables": {
        "10u": "10m U Wind Component",
        "10v": "10m V Wind Component",
        "2d": "2m Dewpoint Temperature",
        "2t": "2m Temperature",
        "msl": "Mean Sea Level Pressure",
        "skt": "Skin Temperature",
        "sp": "Surface Pressure",
        "tcw": "Total Column Water",
        "lsm": "Land-Sea Mask",
        "z": "Surface Geopotential",
        "slor": "Slope of Sub-gridscale Orography",
        "sdor": "Standard Deviation of Orography",
    },
    "Soil Variables": {
        "stl1": "Soil Temperature Level 1",
        "stl2": "Soil Temperature Level 2",
        "swvl1": "Soil Water Volume Level 1",
        "swvl2": "Soil Water Volume Level 2",
    },
    # "<var>_<level>" -> "<Long Name> at <level>hPa", variable-major order.
    "Pressure Level Variables": {
        f"{short_name}_{pressure}": f"{long_name} at {pressure}hPa"
        for short_name, long_name in _PRESSURE_LEVEL_NAMES.items()
        for pressure in LEVELS
    },
}

# Load the model once at startup so requests reuse it; run_forecast replaces
# this global if a different device is requested.
MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device="cuda")  # Default to CUDA

# Create and set custom temp directory (relative to the current working directory).
# Generated MP4s are written here by plot_forecast_animation and pruned by cleanup_old_files.
# NOTE(review): gradio is already imported above, so this relies on Gradio reading
# GRADIO_TEMP_DIR lazily (e.g. when gr.Blocks is built) — confirm for the pinned version.
TEMP_DIR = Path("./gradio_temp")
TEMP_DIR.mkdir(exist_ok=True)
os.environ['GRADIO_TEMP_DIR'] = str(TEMP_DIR)

def get_open_data(param, levelist=None):
    """Fetch ECMWF open-data fields for the two model input times, regridded to N320.

    Data is requested for ``DEFAULT_DATE - 6h`` and then ``DEFAULT_DATE``; values
    for each field name are stacked in that order.

    Args:
        param: ECMWF parameter short names to request (e.g. ``["2t", "msl"]``).
        levelist: Optional levels (pressure or soil). When non-empty, result keys
            are ``"<param>_<level>"``; otherwise they are just ``"<param>"``.

    Returns:
        dict mapping field name to the ``np.stack`` of its regridded values, one
        entry per date (earlier date first).

    Raises:
        ValueError: if a downloaded field is not on the expected (721, 1440) grid.
    """
    # Avoid a shared mutable default; the source still receives [] when unset.
    levelist = [] if levelist is None else levelist
    dates = [DEFAULT_DATE - datetime.timedelta(hours=6), DEFAULT_DATE]

    fields = {}
    for date in dates:
        print(f"Fetching data for {date}")
        # sources can be seen https://earthkit-data.readthedocs.io/en/latest/guide/sources.html#id57
        data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"Expected a (721, 1440) 0.25° grid for {f.metadata('param')}, got {grid.shape}"
                )
            # Rotate by half the longitude count before regridding.
            # NOTE(review): presumably converts a -180..180 longitude origin to the
            # 0..360 layout the interpolation expects — confirm.
            values = np.roll(grid, -grid.shape[1] // 2, axis=1)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields.setdefault(name, []).append(values)

    # Create a single matrix for each parameter (first axis = date).
    return {name: np.stack(per_date) for name, per_date in fields.items()}

def _describe_variable(var_id):
    """Return the VARIABLE_GROUPS description for *var_id*, falling back to the id itself."""
    for group in VARIABLE_GROUPS.values():
        if var_id in group:
            return group[var_id] or var_id
    return var_id


def _parse_forecast_time(value):
    """Return a state's ``date`` as a datetime.

    Strings in ``%Y-%m-%d %H:%M:%S`` form (optionally with ``.%f``) are parsed;
    strings in any other form fall back to ``DEFAULT_DATE``. Non-string values
    are returned unchanged (assumed to already be datetime-like).
    """
    if not isinstance(value, str):
        return value
    for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"):
        try:
            return datetime.datetime.strptime(value, fmt)
        except ValueError:
            pass
    return DEFAULT_DATE


def plot_forecast_animation(states, selected_variable):
    """Render *selected_variable* from each forecast state as one frame of an MP4.

    Args:
        states: Sequence of model output states (as returned by run_forecast).
            Each must provide "latitudes", "longitudes", "date" and a "fields"
            dict containing *selected_variable*; the grid comes from the first state.
        selected_variable: Key into each state's "fields" (e.g. "2t", "t_850").

    Returns:
        Path (str) of the MP4 written into TEMP_DIR.

    Raises:
        ValueError: if *states* is empty.
    """
    if not states:
        raise ValueError("No forecast states to animate")

    # One shared colour scale for every frame: fixed contour levels spanning the
    # global min/max. With `levels=20` matplotlib would choose new levels from
    # each frame's own data range, so the colorbar would change between frames.
    all_values = [state["fields"][selected_variable] for state in states]
    vmin, vmax = float(np.min(all_values)), float(np.max(all_values))
    if vmax <= vmin:
        # Constant field: widen the range so contour levels are strictly increasing.
        vmin, vmax = vmin - 0.5, vmax + 0.5
    levels = np.linspace(vmin, vmax, 21)  # 21 boundaries -> 20 filled bands

    first_state = states[0]
    latitudes, longitudes = first_state["latitudes"], first_state["longitudes"]
    # Map longitudes above 180 into -180..180 for the PlateCarree plot.
    fixed_lons = np.where(longitudes > 180, longitudes - 360, longitudes)
    triangulation = tri.Triangulation(fixed_lons, latitudes)
    var_name = _describe_variable(selected_variable)

    fig = plt.figure(figsize=(15, 8))
    try:
        ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))
        cbar_ax = None

        def update(frame):
            nonlocal cbar_ax
            ax.clear()

            # Set map features (ax.clear() wiped them).
            ax.set_extent([-180, 180, -85, 85], crs=ccrs.PlateCarree())
            ax.coastlines(resolution='50m')
            ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5)
            ax.gridlines(draw_labels=True)

            state = states[frame]

            # Replace the previous frame's colorbar instead of stacking a new one on top.
            if cbar_ax is not None:
                cbar_ax.remove()

            contour = ax.tricontourf(
                triangulation, state["fields"][selected_variable],
                levels=levels, transform=ccrs.PlateCarree(),
                cmap='RdBu_r', vmin=vmin, vmax=vmax,
            )
            cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.03])  # [left, bottom, width, height]
            fig.colorbar(contour, cax=cbar_ax, orientation='horizontal')

            time_str = _parse_forecast_time(state["date"]).strftime("%Y-%m-%d %H:%M UTC")
            ax.set_title(f"{var_name} - {time_str}")

        anim = animation.FuncAnimation(
            fig, update,
            frames=len(states),
            interval=1000,  # 1 second between frames
            repeat=True,
            blit=False,  # Must be False: the colorbar axes are recreated every frame
        )

        # Save as MP4
        temp_file = str(TEMP_DIR / f"forecast_{datetime.datetime.now().timestamp()}.mp4")
        anim.save(temp_file, writer='ffmpeg', fps=1)
    finally:
        # Close this specific figure, even if rendering fails, so figures don't accumulate.
        plt.close(fig)

    return temp_file

def run_forecast(date, lead_time, device):
    """Assemble the AIFS input state from ECMWF open data and run the model.

    Args:
        date: Date stored in the input state. Note the data itself is always
            fetched relative to DEFAULT_DATE by get_open_data.
        lead_time: Forecast horizon passed straight to the runner.
        device: Device the model should run on; if it differs from the loaded
            model's device, the global MODEL is reloaded on it.

    Returns:
        list of every state yielded by ``MODEL.run``.
    """
    global MODEL

    # Open-data soil names -> names used in the model's input state.
    soil_renames = {
        'sot_1': 'stl1', 'sot_2': 'stl2',
        'vsw_1': 'swvl1', 'vsw_2': 'swvl2'
    }
    standard_gravity = 9.80665  # m/s^2, converts geopotential height to geopotential

    fields = {}
    fields.update(get_open_data(param=PARAM_SFC))
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
    fields.update({soil_renames[key]: arr for key, arr in soil.items()})
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

    # Replace each gh_<level> with z_<level> (geopotential).
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * standard_gravity

    # Reload the model only when a different device is requested.
    if MODEL.device != device:
        MODEL = SimpleRunner("aifs-single-mse-1.0.ckpt", device=device)

    input_state = {"date": date, "fields": fields}
    return list(MODEL.run(input_state=input_state, lead_time=lead_time))

def update_plot(lead_time, variable):
    """Gradio handler: run a forecast and return the animation video path.

    Args:
        lead_time: Hours ahead to forecast (from the slider).
        variable: Selected dropdown value, i.e. a field key such as "2t".

    Returns:
        Path (str) of the generated MP4 for the ``gr.Video`` output.

    Raises:
        gr.Error: if no plottable variable is selected.
    """
    # The dropdown's group-header entries carry the value None. Reject them
    # before starting the slow forecast, which would otherwise only fail
    # afterwards with a KeyError in the plotting step.
    if not variable:
        raise gr.Error("Please select a variable to plot (group headers are not selectable).")
    cleanup_old_files()  # Clean up old files before creating new ones
    states = run_forecast(DEFAULT_DATE, lead_time, "cuda")
    return plot_forecast_animation(states, variable)

def cleanup_old_files(max_age_seconds=3600, directory=None):
    """Delete ``*.mp4`` files whose modification time is older than *max_age_seconds*.

    Args:
        max_age_seconds: Age threshold in seconds; files strictly older are
            removed. Defaults to one hour.
        directory: Directory to sweep; defaults to TEMP_DIR.
    """
    target = TEMP_DIR if directory is None else Path(directory)
    now = datetime.datetime.now().timestamp()
    for video in target.glob("*.mp4"):
        try:
            age = now - video.stat().st_mtime
        except FileNotFoundError:
            # Removed by a concurrent request between glob() and stat(); nothing to do.
            continue
        if age > max_age_seconds:
            video.unlink(missing_ok=True)

# Flatten VARIABLE_GROUPS into (label, value) pairs for gr.Dropdown: each group
# starts with a header entry whose value is None, followed by its variables
# sorted by id and labelled "<description> (<id>)".
DROPDOWN_CHOICES = []
for group_label, members in VARIABLE_GROUPS.items():
    header = (f"── {group_label} ──", None)
    entries = [(f"{description} ({key})", key) for key, description in sorted(members.items())]
    DROPDOWN_CHOICES.extend([header, *entries])

# Build the Gradio UI: a header, the controls (lead-time slider, variable
# dropdown, Clear/Submit buttons), the video output and a footer. The event
# handlers are wired at the end of the block.
with gr.Blocks(css="""
    .centered-header {
        text-align: center;
        margin-bottom: 20px;
    }
    .subtitle {
        font-size: 1.2em;
        line-height: 1.5;
        margin: 20px 0;
    }
    .footer {
        text-align: center;
        padding: 20px;
        margin-top: 20px;
        border-top: 1px solid #eee;
    }
""") as demo:
    # Header section (shows the DEFAULT_DATE used as the forecast start)
    gr.Markdown(f"""
    # AIFS Weather Forecast
    
    <div class="subtitle">
    Interactive visualization of ECMWF AIFS weather forecasts.<br>
    Starting from the latest available data ({DEFAULT_DATE.strftime('%Y-%m-%d %H:%M UTC')}),<br>
    select how many hours ahead you want to forecast and which meteorological variable to visualize.
    </div>
    """)
    
    # Main content: controls on the left, video on the right
    with gr.Row():
        with gr.Column(scale=1):
            # Lead time in hours, in 6-hour steps
            lead_time = gr.Slider(
                minimum=6, 
                maximum=48, 
                step=6, 
                value=12, 
                label="Forecast Hours Ahead"
            )
            # Choices include group-header entries whose value is None
            variable = gr.Dropdown(
                choices=DROPDOWN_CHOICES,
                value="2t",
                label="Select Variable to Plot"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        
        with gr.Column(scale=2):
            animation_output = gr.Video()
    
    # Footer with fork instructions and model reference
    gr.Markdown("""
    <div class="footer">
    <h3>Want to run this on your own?</h3>
    You can fork this space and run it yourself:
    
    1. Visit <a href="https://huggingface.co/spaces/geobase/aifs-forecast" target="_blank">https://huggingface.co/spaces/geobase/aifs-forecast</a>\n
    2. Click the "Duplicate this Space" button in the top right\n
    3. Select your hardware requirements (GPU recommended)\n
    4. Wait for your copy to deploy
    
    <h3>Model Information</h3>
    This demo uses the <a href="https://huggingface.co/ecmwf/aifs-single-1.0" target="_blank">AIFS Single 1.0</a> model from ECMWF, 
    which is their first operationally supported Artificial Intelligence Forecasting System. The model produces highly skilled forecasts 
    for upper-air variables, surface weather parameters, and tropical cyclone tracks.
    
    Note: If you encounter any issues with this demo, trying your own fork might work better!
    </div>
    """)
    
    def clear():
        """Reset the controls to their initial values and empty the video.

        Returns values for [lead_time, variable, animation_output], in the same
        order as the outputs list given to clear_btn.click.
        """
        return [
            12,
            "2t",
            None
        ]
    
    # Connect the inputs to the forecast function (Submit runs update_plot)
    submit_btn.click(
        fn=update_plot, 
        inputs=[lead_time, variable], 
        outputs=animation_output
    )
    clear_btn.click(
        fn=clear, 
        inputs=[], 
        outputs=[lead_time, variable, animation_output]
    )

# Start the Gradio web server for the UI defined above.
demo.launch()