# Uploaded by: jithin14
# Commit message: "Add Gradio app and dependencies"
# Commit: 0704ad2
import json
import numpy as np
import pandas as pd
import torch
import dgl
import gradio as gr
import plotly.graph_objects as go
import plotly.express as px
from math import radians, cos, sin
from sklearn.neighbors import KDTree
import torch.nn as nn
# Min-max normalization bounds. They are used to scale the x/y coordinates
# (longitude/latitude) and wind speed v into [0, 1] before inference, and to
# map predicted speeds back into [v_min, v_max] afterwards.
# NOTE(review): presumably derived from the training data -- confirm they
# match the checkpoint being loaded.
NORM_PARAMS = dict(
    v_min=0.0508155134766646,
    v_max=23.99703788008771,
    x_min=-97.40959,
    x_max=-96.55169890366584,
    y_min=32.587689999999995,
    y_max=33.067024421728206,
)
# Graph neural network producing a per-node speed and unit direction vector.
class WindGNN(nn.Module):
    """Residual stack of DGL ``GraphConv`` layers followed by two heads.

    Node features are projected to ``hidden_size`` and passed through
    ``num_layers`` graph convolutions. Each convolution is followed by
    LayerNorm -> ReLU -> dropout. The first convolution's output replaces
    the projection; every later one is added residually. Two heads then
    decode the hidden state:

    * ``velocity_head`` -> one value squashed into (0, 1) by a sigmoid;
    * ``direction_head`` -> a 2-vector L2-normalized to unit length.

    Attribute names and ``nn.Sequential`` positions determine the
    ``state_dict`` keys that ``load_state_dict`` expects, so they must not
    be renamed or reordered.

    Args:
        in_feats: size of each node's input feature vector.
        hidden_size: width of the hidden node representation.
        out_feats: kept for interface compatibility; the output width is
            always 3 (speed + 2-D direction).
        num_layers: number of graph convolution layers.
    """

    def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3):
        super().__init__()
        # Construction order is kept fixed so parameter initialization
        # consumes the RNG in the same sequence.
        self.layers = nn.ModuleList()
        self.input_proj = nn.Linear(in_feats, hidden_size)
        self.layers.extend(
            dgl.nn.GraphConv(hidden_size, hidden_size) for _ in range(num_layers)
        )
        self.velocity_head = self._make_head(hidden_size, 1, nn.Sigmoid())
        self.direction_head = self._make_head(hidden_size, 2)
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden_size) for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(0.2)

    @staticmethod
    def _make_head(hidden_size, out_dim, *tail):
        """Build Linear -> LayerNorm -> ReLU -> Dropout -> Linear, then ``tail``."""
        bottleneck = hidden_size // 4
        return nn.Sequential(
            nn.Linear(hidden_size, bottleneck),
            nn.LayerNorm(bottleneck),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, out_dim),
            *tail,
        )

    def forward(self, g, features):
        """Run the network on graph ``g``.

        Args:
            g: DGL graph whose nodes correspond to the rows of ``features``.
            features: 2-D float tensor, one row of ``in_feats`` values per node.

        Returns:
            Tensor with 3 columns per node. Column 0 is the sigmoid-scaled
            speed; columns 1-2 are the unit-length direction vector.
        """
        h = self.input_proj(features)
        for depth, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
            update = self.dropout(nn.functional.relu(norm(conv(g, h))))
            # First conv replaces the projection; later convs are residual.
            h = update if depth == 0 else h + update
        speed = self.velocity_head(h)
        heading = nn.functional.normalize(self.direction_head(h), dim=1)
        return torch.cat([speed, heading], dim=1)
# Map raw model outputs back to physical speed and direction in degrees.
def denormalize_predictions(pred, norm_params=None):
    """Convert model output rows into wind speed and direction.

    Args:
        pred: array-like with 3 columns per row. Column 0 is the normalized
            speed in [0, 1]; columns 1 and 2 are the sine and cosine of the
            direction.
        norm_params: mapping providing ``'v_min'`` and ``'v_max'``. Defaults
            to the module-level ``NORM_PARAMS``.

    Returns:
        ``(velocity, direction)`` NumPy arrays:

        * ``velocity`` is rescaled from [0, 1] to [v_min, v_max];
        * ``direction`` is in degrees, reduced modulo 360.
    """
    params = NORM_PARAMS if norm_params is None else norm_params
    pred = np.asarray(pred)
    v_min, v_max = params['v_min'], params['v_max']
    velocity = pred[:, 0] * (v_max - v_min) + v_min
    # Columns 1 and 2 are the (sin, cos) pair, so arctan2(sin, cos) recovers the angle.
    direction = np.rad2deg(np.arctan2(pred[:, 1], pred[:, 2])) % 360
    return velocity, direction
# Build a Plotly figure outlining each box and annotating its mean errors.
def plot_errors(box_errors):
    """Draw one rectangle per box and label it with its velocity/direction error.

    Args:
        box_errors: iterable of dicts with keys:

            * ``'coords'``: array of shape (n, 2) with (longitude, latitude)
              points belonging to the box;
            * ``'velocity_error'``: number rendered as ``"<value> m/s"``;
            * ``'direction_error'``: number rendered as ``"<value>°"``.

            A box whose ``coords`` is empty is skipped, because no bounding
            rectangle can be computed for it.

    Returns:
        A ``plotly.graph_objects.Figure`` with equal x/y axis scaling.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set3
    for i, box_error in enumerate(box_errors):
        coords = np.asarray(box_error['coords'])
        if coords.size == 0:
            # np.min/np.max raise on empty arrays; nothing to draw here.
            continue
        lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
        lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
        center_lon = np.mean([lon_min, lon_max])
        # Labels sit below the box's vertical midpoint (35% of its height).
        center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35

        # Box boundary as a closed polyline; color cycles through the palette.
        fig.add_trace(go.Scatter(
            x=[lon_min, lon_max, lon_max, lon_min, lon_min],
            y=[lat_min, lat_min, lat_max, lat_max, lat_min],
            mode='lines',
            line=dict(color=colors[i % len(colors)])
        ))

        # Velocity error label, placed above the direction label.
        fig.add_annotation(
            x=center_lon,
            y=center_lat + 0.02,
            text=f"{box_error['velocity_error']:.2f} m/s",
            showarrow=False,
            font=dict(size=10)
        )
        # Direction error label.
        fig.add_annotation(
            x=center_lon,
            y=center_lat - 0.01,
            text=f"{box_error['direction_error']:.2f}°",
            showarrow=False,
            font=dict(size=10)
        )

    fig.update_layout(
        title='Wind Prediction Error After 15 Minutes',
        xaxis_title='Longitude',
        yaxis_title='Latitude',
        showlegend=False,
        hovermode='closest',
        width=700,
        height=700
    )
    # Keep one degree of longitude and latitude the same length on screen.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig
# Number of nearest neighbours each point is connected to in the DGL graph.
_NUM_NEIGHBORS = 8
# Number of consecutive row groups ("boxes") the error summary is split into.
_NUM_BOXES = 24
# Loaded models keyed by device string, so the checkpoint is read once per
# device instead of once per request.
_MODEL_CACHE = {}


def _load_model(device):
    """Return an eval-mode WindGNN with weights from ``5.pth`` on ``device`` (cached)."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = WindGNN().to(device)
        checkpoint = torch.load('5.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def _normalize_features(coords, speeds, directions):
    """Stack min-max-scaled x, y, speed and sin/cos of direction (degrees) as columns."""
    norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
    norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
    norm_speeds = (speeds - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])
    directions_rad = np.deg2rad(directions)
    return np.column_stack([norm_x, norm_y, norm_speeds,
                            np.sin(directions_rad), np.cos(directions_rad)])


def _build_knn_graph(coords, device, k=_NUM_NEIGHBORS):
    """Build a DGL graph with an edge from every point to each of its k nearest other points.

    The graph is moved to ``device`` so that node features on the same
    device can be attached to it.
    """
    # Ask for k + 1 hits because the closest one is normally the point itself.
    _, indices = KDTree(coords).query(coords, k=k + 1)
    n = len(coords)
    # Same edge order as a nested "for each point, for each neighbour" loop.
    src = torch.as_tensor(np.repeat(np.arange(n), k), dtype=torch.int64)
    dst = torch.as_tensor(indices[:, 1:].ravel(), dtype=torch.int64)
    # NOTE(review): GraphConv (as constructed in WindGNN) rejects nodes with
    # zero in-degree by default; a point that is in no other point's k-NN
    # list would make the forward pass raise -- confirm for the input grids.
    return dgl.graph((src, dst), num_nodes=n).to(device)


def _angular_difference(a, b):
    """Smallest absolute angle in degrees between directions ``a`` and ``b`` (in [0, 180])."""
    diff = np.abs(np.asarray(a) - np.asarray(b)) % 360
    return np.minimum(diff, 360 - diff)


def _split_into_boxes(coords, velocity_errors, direction_errors, num_boxes=_NUM_BOXES):
    """Group consecutive rows into ``num_boxes`` equal-sized boxes and average their errors.

    Each box holds ``len(coords) // num_boxes`` rows. Any remainder rows at
    the end are not assigned to a box. Returns an empty list when there are
    fewer rows than boxes.
    NOTE(review): assumes the CSV rows are already ordered box by box --
    confirm against the data layout.
    """
    points_per_box = len(coords) // num_boxes
    if points_per_box == 0:
        return []
    box_errors = []
    for i in range(num_boxes):
        rows = slice(i * points_per_box, (i + 1) * points_per_box)
        box_errors.append({
            'coords': coords[rows],
            'velocity_error': np.mean(velocity_errors[rows]),
            'direction_error': np.mean(direction_errors[rows]),
        })
    return box_errors


def inference(current_csv, future_csv):
    """Predict the next wind state from ``current_csv`` and score it against ``future_csv``.

    Args:
        current_csv: CSV (anything ``pd.read_csv`` accepts) with columns
            ``'x'``, ``'y'``, ``'v'`` (speed) and ``'d'`` (direction in
            degrees) for the current time step.
        future_csv: CSV with columns ``'v'`` and ``'d'`` for the later time
            step. Rows are compared positionally with ``current_csv``, so
            row i is assumed to describe the same location in both files.

    Returns:
        ``(fig, error_info, result_df)``:

        * a Plotly figure of per-box mean errors;
        * a text summary of mean/max/min errors;
        * a DataFrame of original vs. predicted speed and direction.

    Raises:
        gr.Error: if the CSVs have different row counts, or there are too
            few rows to build the neighbour graph.
    """
    current_data = pd.read_csv(current_csv)
    future_data = pd.read_csv(future_csv)
    if len(current_data) != len(future_data):
        raise gr.Error(
            f"Row count mismatch: current CSV has {len(current_data)} rows, "
            f"future CSV has {len(future_data)}."
        )
    if len(current_data) < _NUM_NEIGHBORS + 1:
        raise gr.Error(
            f"Need at least {_NUM_NEIGHBORS + 1} rows to build the neighbour "
            f"graph; got {len(current_data)}."
        )

    coords = current_data[['x', 'y']].to_numpy()
    speeds = current_data['v'].to_numpy()
    directions = current_data['d'].to_numpy()
    features = _normalize_features(coords, speeds, directions)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(device)
    g = _build_knn_graph(coords, device)
    g.ndata['feat'] = torch.as_tensor(features, dtype=torch.float32, device=device)

    with torch.no_grad():
        predictions = model(g, g.ndata['feat']).cpu().numpy()
    pred_velocity, pred_direction = denormalize_predictions(predictions)

    true_velocity = future_data['v'].to_numpy()
    true_direction = future_data['d'].to_numpy()
    velocity_errors = np.abs(pred_velocity - true_velocity)
    # Directions are circular: 359° vs 1° is a 2° error, not 358°.
    direction_errors = _angular_difference(pred_direction, true_direction)

    mean_velocity_error = np.mean(velocity_errors)
    mean_direction_error = np.mean(direction_errors)
    max_velocity_error = np.max(velocity_errors)
    min_velocity_error = np.min(velocity_errors)
    max_direction_error = np.max(direction_errors)
    min_direction_error = np.min(direction_errors)

    fig = plot_errors(_split_into_boxes(coords, velocity_errors, direction_errors))

    error_info = (
        f"Mean Velocity Error: {mean_velocity_error:.3f} m/s, "
        f"Mean Direction Error: {mean_direction_error:.3f}°\n"
        f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
        f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
        f"Max Direction Error: {max_direction_error:.3f}°, "
        f"Min Direction Error: {min_direction_error:.3f}°"
    )

    # Side-by-side table of observed (future) and predicted values per point.
    result_df = pd.DataFrame({
        'x': current_data['x'],
        'y': current_data['y'],
        'Original Velocity': true_velocity,
        'Predicted Velocity': pred_velocity,
        'Original Direction': true_direction,
        'Predicted Direction': pred_direction
    })
    return fig, error_info, result_df
# Example input pairs: each row is a (current, 15-minutes-later) CSV snapshot.
example_csv_files = [
    ["2021-05-03_0500.csv", "2021-05-03_0515.csv"],
    ["2021-05-03_0515.csv", "2021-05-03_0530.csv"],
    ["2021-07-01_0700.csv", "2021-07-01_0715.csv"],
    ["2021-07-01_0715.csv", "2021-07-01_0730.csv"]
]

# Gradio interface: two CSV uploads in; figure, error summary and table out.
iface = gr.Interface(
    fn=inference,
    inputs=[gr.File(file_types=['.csv'], label="Current Wind Data CSV"),
            gr.File(file_types=['.csv'], label="Future Wind Data CSV")],
    outputs=["plot", "text", gr.DataFrame(label="Prediction Results")],
    title="Wind Prediction Model",
    description="Upload CSV files containing current and 15-minute future wind data to see detailed prediction errors per box.",
    examples=example_csv_files
)

# Only start the (blocking) server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()