# Hugging Face Space file (uploaded by jithin14),
# commit 0704ad2 "Add Gradio app and dependencies", 8.67 kB.
import json
import numpy as np
import pandas as pd
import torch
import dgl
import gradio as gr
import plotly.graph_objects as go
import plotly.express as px
from math import radians, cos, sin
from sklearn.neighbors import KDTree
import torch.nn as nn
# Min/max ranges used to scale inputs into [0, 1] and to map model output
# back to physical units. The same values must be used for normalization in
# ``inference`` and for denormalization in ``denormalize_predictions``.
#   v_*: wind speed (reported as m/s in the plot annotations)
#   x_*: longitude bounds, y_*: latitude bounds (plotted as Longitude/Latitude)
NORM_PARAMS = dict(
    v_min=0.0508155134766646,
    v_max=23.99703788008771,
    x_min=-97.40959,
    x_max=-96.55169890366584,
    y_min=32.587689999999995,
    y_max=33.067024421728206,
)
# Define the model class
class WindGNN(nn.Module):
    """Graph neural network predicting per-node wind speed and direction.

    Node features are linearly projected to ``hidden_size`` and refined by
    ``num_layers`` GraphConv blocks (conv -> LayerNorm -> ReLU -> Dropout).
    The first block replaces the projection; later blocks are added
    residually. Two heads then read the node embeddings:

    * ``velocity_head`` ends in a sigmoid, so speed comes out in [0, 1]
      (normalized units);
    * ``direction_head`` emits a 2-vector that is L2-normalized in
      ``forward``.

    ``forward`` returns ``[speed, dir_0, dir_1]`` per node (3 columns).
    ``out_feats`` is accepted for signature compatibility but not used; the
    output width is fixed by the two heads.
    """

    def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3):
        super().__init__()
        quarter = hidden_size // 4

        # ``layers`` is registered (empty) before ``input_proj`` and filled
        # afterwards. This keeps the module naming order and the random
        # initialization order identical to the saved checkpoint's model.
        self.layers = nn.ModuleList()
        self.input_proj = nn.Linear(in_feats, hidden_size)
        self.layers.extend(
            [dgl.nn.GraphConv(hidden_size, hidden_size) for _ in range(num_layers)]
        )

        def make_head(out_dim, *tail):
            # Linear -> LayerNorm -> ReLU -> Dropout -> Linear (+ optional tail).
            # Sequential indices must match the checkpoint's state_dict keys.
            return nn.Sequential(
                nn.Linear(hidden_size, quarter),
                nn.LayerNorm(quarter),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(quarter, out_dim),
                *tail,
            )

        self.velocity_head = make_head(1, nn.Sigmoid())
        self.direction_head = make_head(2)
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, features):
        """Run the GNN on graph ``g`` with node ``features``; return (N, 3)."""
        h = self.input_proj(features)
        for depth, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
            update = self.dropout(nn.functional.relu(norm(conv(g, h))))
            # First block overwrites the projection; later blocks are residual.
            h = update if depth == 0 else h + update
        speed = self.velocity_head(h)
        heading = nn.functional.normalize(self.direction_head(h), dim=1)
        return torch.cat([speed, heading], dim=1)
# Denormalize predictions
def denormalize_predictions(pred, norm_params=None):
    """Convert raw model outputs back to wind speed and heading in degrees.

    Args:
        pred: array-like with at least 3 columns. Column 0 is the normalized
            speed. Columns 1 and 2 are interpreted as the sine and cosine
            components of the direction (passed to ``arctan2`` in that
            order).
        norm_params: optional mapping providing ``'v_min'`` and ``'v_max'``.
            Defaults to the module-level ``NORM_PARAMS``.

    Returns:
        ``(velocity, direction)`` as 1-D arrays. ``velocity`` is rescaled
        from [0, 1] to [v_min, v_max]; ``direction`` is in degrees,
        wrapped into [0, 360).
    """
    params = NORM_PARAMS if norm_params is None else norm_params
    pred = np.asarray(pred)
    v_min, v_max = params['v_min'], params['v_max']
    velocity = pred[:, 0] * (v_max - v_min) + v_min
    # arctan2 yields (-180, 180]; the modulo maps it onto [0, 360).
    direction = np.rad2deg(np.arctan2(pred[:, 1], pred[:, 2])) % 360
    return velocity, direction
# Function to create plotly visualization with error annotations
def plot_errors(box_errors):
    """Draw each box outline annotated with its mean velocity/direction error.

    Args:
        box_errors: iterable of dicts with keys ``'coords'`` (array-like of
            ``(x, y)`` rows, plotted as longitude/latitude),
            ``'velocity_error'`` and ``'direction_error'`` (scalars).
            Boxes whose ``coords`` are empty are skipped. No extent can be
            computed for them, and ``np.min``/``np.max`` would raise.

    Returns:
        A ``plotly.graph_objects.Figure`` with one outline per box and two
        text annotations per box.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set3
    for i, box_error in enumerate(box_errors):
        coords = np.asarray(box_error['coords'])
        if coords.size == 0:
            continue
        lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
        lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
        center_lon = np.mean([lon_min, lon_max])
        # Shift the label anchor below the box centre (35% of box height).
        center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35
        # Add box boundaries (closed rectangle)
        fig.add_trace(go.Scatter(
            x=[lon_min, lon_max, lon_max, lon_min, lon_min],
            y=[lat_min, lat_min, lat_max, lat_max, lat_min],
            mode='lines',
            line=dict(color=colors[i % len(colors)])
        ))
        # Velocity error label sits above the anchor ...
        fig.add_annotation(
            x=center_lon,
            y=center_lat + 0.02,
            text=f"{box_error['velocity_error']:.2f} m/s",
            showarrow=False,
            font=dict(size=10)
        )
        # ... and the direction error label slightly below it.
        fig.add_annotation(
            x=center_lon,
            y=center_lat - 0.01,
            text=f"{box_error['direction_error']:.2f}°",
            showarrow=False,
            font=dict(size=10)
        )
    fig.update_layout(
        title='Wind Prediction Error After 15 Minutes',
        xaxis_title='Longitude',
        yaxis_title='Latitude',
        showlegend=False,
        hovermode='closest',
        width=700,
        height=700
    )
    # Equal axis scaling so boxes are not visually distorted.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig
# Inference function for Gradio interface

# Checkpoint consumed by ``inference``.
_CHECKPOINT_PATH = '5.pth'
# Each node is linked to this many nearest neighbours (itself excluded).
_NUM_NEIGHBORS = 8
# Rows are grouped into this many consecutive boxes for the error plot.
_NUM_BOXES = 24
# Loaded models keyed by device string, so the checkpoint is read once.
_MODEL_CACHE = {}


def _load_model(device):
    """Return a WindGNN in eval mode on ``device``, loading the checkpoint once."""
    key = str(device)
    if key not in _MODEL_CACHE:
        model = WindGNN().to(device)
        checkpoint = torch.load(_CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[key] = model
    return _MODEL_CACHE[key]


def _build_knn_graph(coords, device):
    """Build a DGL graph with an edge from each point to its nearest neighbours."""
    n = len(coords)
    # +1 because the nearest hit is normally the query point itself.
    k = min(_NUM_NEIGHBORS + 1, n)
    _, indices = KDTree(coords).query(coords, k=k)
    # Drop column 0 (the point itself); edges are ordered point by point.
    src = np.repeat(np.arange(n), k - 1)
    dst = indices[:, 1:].reshape(-1)
    g = dgl.graph((torch.as_tensor(src), torch.as_tensor(dst)), num_nodes=n)
    # The graph must be on the same device as the node features assigned to it.
    return g.to(device)


def _angular_difference(a, b):
    """Smallest absolute difference between headings in degrees, in [0, 180]."""
    diff = np.abs(np.asarray(a) - np.asarray(b)) % 360
    return np.minimum(diff, 360 - diff)


def _box_errors(coords, velocity_errors, direction_errors, num_boxes=_NUM_BOXES):
    """Group rows into ``num_boxes`` consecutive boxes and average their errors.

    Each box gets ``len(coords) // num_boxes`` rows. Leftover rows go into
    the last box instead of being dropped, and empty boxes are omitted.
    NOTE(review): assumes the CSV rows are ordered box by box - confirm
    against the data source.
    """
    n = len(coords)
    per_box = n // num_boxes
    boxes = []
    for i in range(num_boxes):
        start = i * per_box
        end = n if i == num_boxes - 1 else (i + 1) * per_box
        if start >= end:
            continue
        boxes.append({
            'coords': coords[start:end],
            'velocity_error': np.mean(velocity_errors[start:end]),
            'direction_error': np.mean(direction_errors[start:end]),
        })
    return boxes


def inference(current_csv, future_csv):
    """Predict wind 15 minutes ahead and compare against the future snapshot.

    Args:
        current_csv: path or file object readable by ``pd.read_csv`` with
            columns ``x``, ``y``, ``v`` (speed) and ``d`` (direction, degrees).
        future_csv: the later snapshot with columns ``v`` and ``d``. Rows
            are compared element-wise, so they are assumed to be in the same
            order as ``current_csv``.

    Returns:
        ``(figure, error_summary_text, result_dataframe)`` for the Gradio UI.

    Raises:
        gr.Error: if the two CSVs have different row counts.
    """
    current_data = pd.read_csv(current_csv)
    future_data = pd.read_csv(future_csv)
    if len(future_data) != len(current_data):
        raise gr.Error(
            f"Row count mismatch: current CSV has {len(current_data)} rows, "
            f"future CSV has {len(future_data)}."
        )

    coords = current_data[['x', 'y']].to_numpy()
    speeds = current_data['v'].to_numpy()
    directions = current_data['d'].to_numpy()

    # Normalize with the training-time ranges. Direction is encoded as
    # (sin, cos) so that 0 and 360 degrees produce the same features.
    norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
    norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
    norm_speeds = (speeds - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])
    directions_rad = np.deg2rad(directions)
    features = np.column_stack(
        [norm_x, norm_y, norm_speeds, np.sin(directions_rad), np.cos(directions_rad)]
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(device)
    g = _build_knn_graph(coords, device)
    g.ndata['feat'] = torch.as_tensor(features, dtype=torch.float32, device=device)

    with torch.no_grad():
        predictions = model(g, g.ndata['feat']).cpu().numpy()
    pred_velocity, pred_direction = denormalize_predictions(predictions)

    true_velocity = future_data['v'].to_numpy()
    true_direction = future_data['d'].to_numpy()
    velocity_errors = np.abs(pred_velocity - true_velocity)
    # Headings wrap at 360 degrees: 359 vs 1 is a 2-degree error, not 358.
    direction_errors = _angular_difference(pred_direction, true_direction)

    fig = plot_errors(_box_errors(coords, velocity_errors, direction_errors))

    error_info = (
        f"Mean Velocity Error: {np.mean(velocity_errors):.3f} m/s, "
        f"Mean Direction Error: {np.mean(direction_errors):.3f}°\n"
        f"Max Velocity Error: {np.max(velocity_errors):.3f} m/s, "
        f"Min Velocity Error: {np.min(velocity_errors):.3f} m/s\n"
        f"Max Direction Error: {np.max(direction_errors):.3f}°, "
        f"Min Direction Error: {np.min(direction_errors):.3f}°"
    )

    # Combine original and predicted data into a DataFrame for display
    result_df = pd.DataFrame({
        'x': current_data['x'],
        'y': current_data['y'],
        'Original Velocity': true_velocity,
        'Predicted Velocity': pred_velocity,
        'Original Direction': true_direction,
        'Predicted Direction': pred_direction
    })
    return fig, error_info, result_df
# Example inputs: each pair is (current snapshot, snapshot 15 minutes later).
_example_pairs = (
    ("2021-05-03_0500.csv", "2021-05-03_0515.csv"),
    ("2021-05-03_0515.csv", "2021-05-03_0530.csv"),
    ("2021-07-01_0700.csv", "2021-07-01_0715.csv"),
    ("2021-07-01_0715.csv", "2021-07-01_0730.csv"),
)
example_csv_files = [list(pair) for pair in _example_pairs]

# Gradio Interface: two CSV uploads in; plot, summary text and table out.
_csv_inputs = [
    gr.File(file_types=['.csv'], label="Current Wind Data CSV"),
    gr.File(file_types=['.csv'], label="Future Wind Data CSV"),
]
_result_outputs = ["plot", "text", gr.DataFrame(label="Prediction Results")]

iface = gr.Interface(
    fn=inference,
    inputs=_csv_inputs,
    outputs=_result_outputs,
    title="Wind Prediction Model",
    description="Upload CSV files containing current and 15-minute future wind data to see detailed prediction errors per box.",
    examples=example_csv_files,
)
iface.launch()