jithin14 commited on
Commit
0704ad2
·
1 Parent(s): 50df795

Add Gradio app and dependencies

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import dgl
6
+ import gradio as gr
7
+ import plotly.graph_objects as go
8
+ import plotly.express as px
9
+ from math import radians, cos, sin
10
+ from sklearn.neighbors import KDTree
11
+ import torch.nn as nn
12
+
13
+ # Define normalization parameters
14
+ NORM_PARAMS = {
15
+ 'v_min': 0.0508155134766646,
16
+ 'v_max': 23.99703788008771,
17
+ 'x_min': -97.40959,
18
+ 'x_max': -96.55169890366584,
19
+ 'y_min': 32.587689999999995,
20
+ 'y_max': 33.067024421728206
21
+ }
22
+
23
+ # Define the model class
24
+ class WindGNN(nn.Module):
25
+ def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3):
26
+ super(WindGNN, self).__init__()
27
+ self.layers = nn.ModuleList()
28
+ self.input_proj = nn.Linear(in_feats, hidden_size)
29
+
30
+ for _ in range(num_layers):
31
+ self.layers.append(dgl.nn.GraphConv(hidden_size, hidden_size))
32
+
33
+ self.velocity_head = nn.Sequential(
34
+ nn.Linear(hidden_size, hidden_size // 4),
35
+ nn.LayerNorm(hidden_size // 4),
36
+ nn.ReLU(),
37
+ nn.Dropout(0.2),
38
+ nn.Linear(hidden_size // 4, 1),
39
+ nn.Sigmoid()
40
+ )
41
+
42
+ self.direction_head = nn.Sequential(
43
+ nn.Linear(hidden_size, hidden_size // 4),
44
+ nn.LayerNorm(hidden_size // 4),
45
+ nn.ReLU(),
46
+ nn.Dropout(0.2),
47
+ nn.Linear(hidden_size // 4, 2)
48
+ )
49
+
50
+ self.layer_norms = nn.ModuleList([
51
+ nn.LayerNorm(hidden_size) for _ in range(num_layers)
52
+ ])
53
+
54
+ self.dropout = nn.Dropout(0.2)
55
+
56
+ def forward(self, g, features):
57
+ h = self.input_proj(features)
58
+ for i, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
59
+ h_new = conv(g, h)
60
+ h_new = norm(h_new)
61
+ h_new = nn.functional.relu(h_new)
62
+ h_new = self.dropout(h_new)
63
+ h = h + h_new if i > 0 else h_new
64
+ velocity = self.velocity_head(h)
65
+ direction = self.direction_head(h)
66
+ direction = nn.functional.normalize(direction, dim=1)
67
+ return torch.cat([velocity, direction], dim=1)
68
+
69
+ # Denormalize predictions
70
+ def denormalize_predictions(pred):
71
+ velocity = pred[:, 0] * (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min']) + NORM_PARAMS['v_min']
72
+ direction = np.rad2deg(np.arctan2(pred[:, 1], pred[:, 2])) % 360
73
+ return velocity, direction
74
+
75
+ # Function to create plotly visualization with error annotations
76
+ def plot_errors(box_errors):
77
+ fig = go.Figure()
78
+ colors = px.colors.qualitative.Set3
79
+ for i, box_error in enumerate(box_errors):
80
+ coords = box_error['coords']
81
+ lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
82
+ lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
83
+ center_lon = np.mean([lon_min, lon_max])
84
+ center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35
85
+
86
+ # Add box boundaries
87
+ fig.add_trace(go.Scatter(
88
+ x=[lon_min, lon_max, lon_max, lon_min, lon_min],
89
+ y=[lat_min, lat_min, lat_max, lat_max, lat_min],
90
+ mode='lines',
91
+ line=dict(color=colors[i % len(colors)])
92
+ ))
93
+
94
+ # Add annotations for velocity and direction errors
95
+ fig.add_annotation(
96
+ x=center_lon,
97
+ y=center_lat + 0.02, # Adjust position for velocity error at the top
98
+ text=f"{box_error['velocity_error']:.2f} m/s",
99
+ showarrow=False,
100
+ font=dict(size=10) # Adjust font size if needed
101
+ )
102
+ fig.add_annotation(
103
+ x=center_lon,
104
+ y=center_lat - 0.01, # Slightly closer to the center for direction error
105
+ text=f"{box_error['direction_error']:.2f}°",
106
+ showarrow=False,
107
+ font=dict(size=10)
108
+ )
109
+
110
+ fig.update_layout(
111
+ title='Wind Prediction Error After 15 Minutes',
112
+ xaxis_title='Longitude',
113
+ yaxis_title='Latitude',
114
+ showlegend=False,
115
+ hovermode='closest',
116
+ width=700, # Reduced width
117
+ height=700 # Reduced height
118
+ )
119
+ fig.update_yaxes(scaleanchor="x", scaleratio=1)
120
+ return fig
121
+
122
+ # Inference function for Gradio interface
123
+ def inference(current_csv, future_csv):
124
+ current_data = pd.read_csv(current_csv)
125
+ future_data = pd.read_csv(future_csv)
126
+
127
+ # Create graph from current data
128
+ coords = current_data[['x', 'y']].to_numpy()
129
+ speeds = current_data['v'].to_numpy()
130
+ directions = current_data['d'].to_numpy()
131
+
132
+ # Normalize data
133
+ norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
134
+ norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
135
+ norm_speeds = (speeds - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])
136
+ directions_rad = np.deg2rad(directions)
137
+ sin_dir = np.sin(directions_rad)
138
+ cos_dir = np.cos(directions_rad)
139
+
140
+ features = np.column_stack([norm_x, norm_y, norm_speeds, sin_dir, cos_dir])
141
+
142
+ # Load and run model
143
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
144
+ model = WindGNN().to(device)
145
+ checkpoint = torch.load('5.pth', map_location=device)
146
+ model.load_state_dict(checkpoint['model_state_dict'])
147
+ model.eval()
148
+
149
+ # Create DGL graph
150
+ tree = KDTree(coords)
151
+ distances, indices = tree.query(coords, k=9)
152
+ src_nodes = []
153
+ dst_nodes = []
154
+
155
+ for i in range(len(coords)):
156
+ for j in range(1, 9):
157
+ neighbor_idx = indices[i][j]
158
+ src_nodes.append(i)
159
+ dst_nodes.append(neighbor_idx)
160
+
161
+ g = dgl.graph((torch.tensor(src_nodes), torch.tensor(dst_nodes)))
162
+ g.ndata['feat'] = torch.FloatTensor(features).to(device)
163
+
164
+ # Predict with model
165
+ with torch.no_grad():
166
+ predictions = model(g, g.ndata['feat']).cpu().numpy()
167
+
168
+ # Denormalize predictions
169
+ pred_velocity, pred_direction = denormalize_predictions(predictions)
170
+
171
+ # Calculate errors
172
+ true_velocity = future_data['v'].to_numpy()
173
+ true_direction = future_data['d'].to_numpy()
174
+
175
+ velocity_errors = np.abs(pred_velocity - true_velocity)
176
+ direction_errors = np.abs(pred_direction - true_direction)
177
+
178
+ mean_velocity_error = np.mean(velocity_errors)
179
+ mean_direction_error = np.mean(direction_errors)
180
+ max_velocity_error = np.max(velocity_errors)
181
+ min_velocity_error = np.min(velocity_errors)
182
+ max_direction_error = np.max(direction_errors)
183
+ min_direction_error = np.min(direction_errors)
184
+
185
+ # Prepare box errors
186
+ box_errors = []
187
+ points_per_box = len(coords) // 24
188
+ for i in range(24):
189
+ start_idx = i * points_per_box
190
+ end_idx = (i + 1) * points_per_box
191
+ box_error = {
192
+ 'coords': coords[start_idx:end_idx],
193
+ 'velocity_error': np.mean(velocity_errors[start_idx:end_idx]),
194
+ 'direction_error': np.mean(direction_errors[start_idx:end_idx])
195
+ }
196
+ box_errors.append(box_error)
197
+
198
+ fig = plot_errors(box_errors)
199
+
200
+ # Prepare detailed error information
201
+ error_info = (
202
+ f"Mean Velocity Error: {mean_velocity_error:.3f} m/s, "
203
+ f"Mean Direction Error: {mean_direction_error:.3f}°\n"
204
+ f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
205
+ f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
206
+ f"Max Direction Error: {max_direction_error:.3f}°, "
207
+ f"Min Direction Error: {min_direction_error:.3f}°"
208
+ )
209
+
210
+ # Combine original and predicted data into a DataFrame for display
211
+ result_df = pd.DataFrame({
212
+ 'x': current_data['x'],
213
+ 'y': current_data['y'],
214
+ 'Original Velocity': true_velocity,
215
+ 'Predicted Velocity': pred_velocity,
216
+ 'Original Direction': true_direction,
217
+ 'Predicted Direction': pred_direction
218
+ })
219
+
220
+ return fig, error_info, result_df
221
+
222
+ # Paths to example CSV files
223
+ example_csv_files = [
224
+ ["2021-05-03_0500.csv", "2021-05-03_0515.csv"],
225
+ ["2021-05-03_0515.csv", "2021-05-03_0530.csv"],
226
+ ["2021-07-01_0700.csv", "2021-07-01_0715.csv"],
227
+ ["2021-07-01_0715.csv", "2021-07-01_0730.csv"]
228
+ ]
229
+
230
+ # Gradio Interface
231
+ iface = gr.Interface(
232
+ fn=inference,
233
+ inputs=[gr.File(file_types=['.csv'], label="Current Wind Data CSV"),
234
+ gr.File(file_types=['.csv'], label="Future Wind Data CSV")],
235
+ outputs=["plot", "text", gr.DataFrame(label="Prediction Results")],
236
+ title="Wind Prediction Model",
237
+ description="Upload CSV files containing current and 15-minute future wind data to see detailed prediction errors per box.",
238
+ examples=example_csv_files
239
+ )
240
+
241
+ iface.launch()