Spaces:
Sleeping
Sleeping
Add Gradio app and dependencies
Browse files
app.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import dgl
|
6 |
+
import gradio as gr
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
import plotly.express as px
|
9 |
+
from math import radians, cos, sin
|
10 |
+
from sklearn.neighbors import KDTree
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
# Define normalization parameters
|
14 |
+
NORM_PARAMS = {
|
15 |
+
'v_min': 0.0508155134766646,
|
16 |
+
'v_max': 23.99703788008771,
|
17 |
+
'x_min': -97.40959,
|
18 |
+
'x_max': -96.55169890366584,
|
19 |
+
'y_min': 32.587689999999995,
|
20 |
+
'y_max': 33.067024421728206
|
21 |
+
}
|
22 |
+
|
23 |
+
# Define the model class
|
24 |
+
class WindGNN(nn.Module):
|
25 |
+
def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3):
|
26 |
+
super(WindGNN, self).__init__()
|
27 |
+
self.layers = nn.ModuleList()
|
28 |
+
self.input_proj = nn.Linear(in_feats, hidden_size)
|
29 |
+
|
30 |
+
for _ in range(num_layers):
|
31 |
+
self.layers.append(dgl.nn.GraphConv(hidden_size, hidden_size))
|
32 |
+
|
33 |
+
self.velocity_head = nn.Sequential(
|
34 |
+
nn.Linear(hidden_size, hidden_size // 4),
|
35 |
+
nn.LayerNorm(hidden_size // 4),
|
36 |
+
nn.ReLU(),
|
37 |
+
nn.Dropout(0.2),
|
38 |
+
nn.Linear(hidden_size // 4, 1),
|
39 |
+
nn.Sigmoid()
|
40 |
+
)
|
41 |
+
|
42 |
+
self.direction_head = nn.Sequential(
|
43 |
+
nn.Linear(hidden_size, hidden_size // 4),
|
44 |
+
nn.LayerNorm(hidden_size // 4),
|
45 |
+
nn.ReLU(),
|
46 |
+
nn.Dropout(0.2),
|
47 |
+
nn.Linear(hidden_size // 4, 2)
|
48 |
+
)
|
49 |
+
|
50 |
+
self.layer_norms = nn.ModuleList([
|
51 |
+
nn.LayerNorm(hidden_size) for _ in range(num_layers)
|
52 |
+
])
|
53 |
+
|
54 |
+
self.dropout = nn.Dropout(0.2)
|
55 |
+
|
56 |
+
def forward(self, g, features):
|
57 |
+
h = self.input_proj(features)
|
58 |
+
for i, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
|
59 |
+
h_new = conv(g, h)
|
60 |
+
h_new = norm(h_new)
|
61 |
+
h_new = nn.functional.relu(h_new)
|
62 |
+
h_new = self.dropout(h_new)
|
63 |
+
h = h + h_new if i > 0 else h_new
|
64 |
+
velocity = self.velocity_head(h)
|
65 |
+
direction = self.direction_head(h)
|
66 |
+
direction = nn.functional.normalize(direction, dim=1)
|
67 |
+
return torch.cat([velocity, direction], dim=1)
|
68 |
+
|
69 |
+
# Denormalize predictions
|
70 |
+
def denormalize_predictions(pred):
|
71 |
+
velocity = pred[:, 0] * (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min']) + NORM_PARAMS['v_min']
|
72 |
+
direction = np.rad2deg(np.arctan2(pred[:, 1], pred[:, 2])) % 360
|
73 |
+
return velocity, direction
|
74 |
+
|
75 |
+
# Function to create plotly visualization with error annotations
|
76 |
+
def plot_errors(box_errors):
|
77 |
+
fig = go.Figure()
|
78 |
+
colors = px.colors.qualitative.Set3
|
79 |
+
for i, box_error in enumerate(box_errors):
|
80 |
+
coords = box_error['coords']
|
81 |
+
lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
|
82 |
+
lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
|
83 |
+
center_lon = np.mean([lon_min, lon_max])
|
84 |
+
center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35
|
85 |
+
|
86 |
+
# Add box boundaries
|
87 |
+
fig.add_trace(go.Scatter(
|
88 |
+
x=[lon_min, lon_max, lon_max, lon_min, lon_min],
|
89 |
+
y=[lat_min, lat_min, lat_max, lat_max, lat_min],
|
90 |
+
mode='lines',
|
91 |
+
line=dict(color=colors[i % len(colors)])
|
92 |
+
))
|
93 |
+
|
94 |
+
# Add annotations for velocity and direction errors
|
95 |
+
fig.add_annotation(
|
96 |
+
x=center_lon,
|
97 |
+
y=center_lat + 0.02, # Adjust position for velocity error at the top
|
98 |
+
text=f"{box_error['velocity_error']:.2f} m/s",
|
99 |
+
showarrow=False,
|
100 |
+
font=dict(size=10) # Adjust font size if needed
|
101 |
+
)
|
102 |
+
fig.add_annotation(
|
103 |
+
x=center_lon,
|
104 |
+
y=center_lat - 0.01, # Slightly closer to the center for direction error
|
105 |
+
text=f"{box_error['direction_error']:.2f}°",
|
106 |
+
showarrow=False,
|
107 |
+
font=dict(size=10)
|
108 |
+
)
|
109 |
+
|
110 |
+
fig.update_layout(
|
111 |
+
title='Wind Prediction Error After 15 Minutes',
|
112 |
+
xaxis_title='Longitude',
|
113 |
+
yaxis_title='Latitude',
|
114 |
+
showlegend=False,
|
115 |
+
hovermode='closest',
|
116 |
+
width=700, # Reduced width
|
117 |
+
height=700 # Reduced height
|
118 |
+
)
|
119 |
+
fig.update_yaxes(scaleanchor="x", scaleratio=1)
|
120 |
+
return fig
|
121 |
+
|
122 |
+
# Inference function for Gradio interface
|
123 |
+
def inference(current_csv, future_csv):
|
124 |
+
current_data = pd.read_csv(current_csv)
|
125 |
+
future_data = pd.read_csv(future_csv)
|
126 |
+
|
127 |
+
# Create graph from current data
|
128 |
+
coords = current_data[['x', 'y']].to_numpy()
|
129 |
+
speeds = current_data['v'].to_numpy()
|
130 |
+
directions = current_data['d'].to_numpy()
|
131 |
+
|
132 |
+
# Normalize data
|
133 |
+
norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
|
134 |
+
norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
|
135 |
+
norm_speeds = (speeds - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])
|
136 |
+
directions_rad = np.deg2rad(directions)
|
137 |
+
sin_dir = np.sin(directions_rad)
|
138 |
+
cos_dir = np.cos(directions_rad)
|
139 |
+
|
140 |
+
features = np.column_stack([norm_x, norm_y, norm_speeds, sin_dir, cos_dir])
|
141 |
+
|
142 |
+
# Load and run model
|
143 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
144 |
+
model = WindGNN().to(device)
|
145 |
+
checkpoint = torch.load('5.pth', map_location=device)
|
146 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
147 |
+
model.eval()
|
148 |
+
|
149 |
+
# Create DGL graph
|
150 |
+
tree = KDTree(coords)
|
151 |
+
distances, indices = tree.query(coords, k=9)
|
152 |
+
src_nodes = []
|
153 |
+
dst_nodes = []
|
154 |
+
|
155 |
+
for i in range(len(coords)):
|
156 |
+
for j in range(1, 9):
|
157 |
+
neighbor_idx = indices[i][j]
|
158 |
+
src_nodes.append(i)
|
159 |
+
dst_nodes.append(neighbor_idx)
|
160 |
+
|
161 |
+
g = dgl.graph((torch.tensor(src_nodes), torch.tensor(dst_nodes)))
|
162 |
+
g.ndata['feat'] = torch.FloatTensor(features).to(device)
|
163 |
+
|
164 |
+
# Predict with model
|
165 |
+
with torch.no_grad():
|
166 |
+
predictions = model(g, g.ndata['feat']).cpu().numpy()
|
167 |
+
|
168 |
+
# Denormalize predictions
|
169 |
+
pred_velocity, pred_direction = denormalize_predictions(predictions)
|
170 |
+
|
171 |
+
# Calculate errors
|
172 |
+
true_velocity = future_data['v'].to_numpy()
|
173 |
+
true_direction = future_data['d'].to_numpy()
|
174 |
+
|
175 |
+
velocity_errors = np.abs(pred_velocity - true_velocity)
|
176 |
+
direction_errors = np.abs(pred_direction - true_direction)
|
177 |
+
|
178 |
+
mean_velocity_error = np.mean(velocity_errors)
|
179 |
+
mean_direction_error = np.mean(direction_errors)
|
180 |
+
max_velocity_error = np.max(velocity_errors)
|
181 |
+
min_velocity_error = np.min(velocity_errors)
|
182 |
+
max_direction_error = np.max(direction_errors)
|
183 |
+
min_direction_error = np.min(direction_errors)
|
184 |
+
|
185 |
+
# Prepare box errors
|
186 |
+
box_errors = []
|
187 |
+
points_per_box = len(coords) // 24
|
188 |
+
for i in range(24):
|
189 |
+
start_idx = i * points_per_box
|
190 |
+
end_idx = (i + 1) * points_per_box
|
191 |
+
box_error = {
|
192 |
+
'coords': coords[start_idx:end_idx],
|
193 |
+
'velocity_error': np.mean(velocity_errors[start_idx:end_idx]),
|
194 |
+
'direction_error': np.mean(direction_errors[start_idx:end_idx])
|
195 |
+
}
|
196 |
+
box_errors.append(box_error)
|
197 |
+
|
198 |
+
fig = plot_errors(box_errors)
|
199 |
+
|
200 |
+
# Prepare detailed error information
|
201 |
+
error_info = (
|
202 |
+
f"Mean Velocity Error: {mean_velocity_error:.3f} m/s, "
|
203 |
+
f"Mean Direction Error: {mean_direction_error:.3f}°\n"
|
204 |
+
f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
|
205 |
+
f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
|
206 |
+
f"Max Direction Error: {max_direction_error:.3f}°, "
|
207 |
+
f"Min Direction Error: {min_direction_error:.3f}°"
|
208 |
+
)
|
209 |
+
|
210 |
+
# Combine original and predicted data into a DataFrame for display
|
211 |
+
result_df = pd.DataFrame({
|
212 |
+
'x': current_data['x'],
|
213 |
+
'y': current_data['y'],
|
214 |
+
'Original Velocity': true_velocity,
|
215 |
+
'Predicted Velocity': pred_velocity,
|
216 |
+
'Original Direction': true_direction,
|
217 |
+
'Predicted Direction': pred_direction
|
218 |
+
})
|
219 |
+
|
220 |
+
return fig, error_info, result_df
|
221 |
+
|
222 |
+
# Paths to example CSV files
|
223 |
+
example_csv_files = [
|
224 |
+
["2021-05-03_0500.csv", "2021-05-03_0515.csv"],
|
225 |
+
["2021-05-03_0515.csv", "2021-05-03_0530.csv"],
|
226 |
+
["2021-07-01_0700.csv", "2021-07-01_0715.csv"],
|
227 |
+
["2021-07-01_0715.csv", "2021-07-01_0730.csv"]
|
228 |
+
]
|
229 |
+
|
230 |
+
# Gradio Interface
|
231 |
+
iface = gr.Interface(
|
232 |
+
fn=inference,
|
233 |
+
inputs=[gr.File(file_types=['.csv'], label="Current Wind Data CSV"),
|
234 |
+
gr.File(file_types=['.csv'], label="Future Wind Data CSV")],
|
235 |
+
outputs=["plot", "text", gr.DataFrame(label="Prediction Results")],
|
236 |
+
title="Wind Prediction Model",
|
237 |
+
description="Upload CSV files containing current and 15-minute future wind data to see detailed prediction errors per box.",
|
238 |
+
examples=example_csv_files
|
239 |
+
)
|
240 |
+
|
241 |
+
iface.launch()
|