Spaces:
Sleeping
Sleeping
import gradio as gr | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import numpy as np | |
import tensorflow as tf | |
import joblib | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
# Load the raw web-traffic dataset. It must contain a "Sessions" column
# (read further down); an "Hour Index" column, if present, is replaced below.
webtraffic_data = pd.read_csv("webtraffic.csv")
# Rename 'Hour Index' to 'Datetime' (a no-op if the column is absent). The
# column's values are overwritten on the next statement, so only the name
# survives.
webtraffic_data.rename(columns={"Hour Index": "Datetime"}, inplace=True)
# Synthetic hourly timestamps for the x-axis of the plots only; the real start
# date of the data is not known here.
# "h" is the hourly alias pandas >= 2.2 expects ("H" emits a FutureWarning there).
webtraffic_data["Datetime"] = pd.date_range(
    start="2023-01-01", periods=len(webtraffic_data), freq="h"
)
# Chronological 80/20 train/test split. There is no shuffling, so the test set
# is the most recent 20% of the rows.
train_size = int(len(webtraffic_data) * 0.8)
test_size = len(webtraffic_data) - train_size
train_data = webtraffic_data.iloc[:train_size]
test_data = webtraffic_data.iloc[train_size:]
# Pre-trained models stored next to this script.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code,
# so "sarima_model.pkl" must come from a trusted source (here: the repo itself).
sarima_model = joblib.load("sarima_model.pkl")
lstm_model = tf.keras.models.load_model("lstm_model.keras")

# Forecast horizon: one step per row of the held-out test set.
future_periods = test_data.shape[0]
# NOTE(review): assumes sarima_model was fit on exactly train_data, so that the
# forecast starts at the first test row. TODO confirm against the training code.
sarima_predictions = sarima_model.forecast(steps=future_periods)
# Prepare data for LSTM predictions.
from sklearn.preprocessing import MinMaxScaler

# The target series as column vectors of shape (n_rows, 1), which is the 2-D
# layout MinMaxScaler works on.
_train_sessions = train_data["Sessions"].values.reshape(-1, 1)
_test_sessions = test_data["Sessions"].values.reshape(-1, 1)

# Both scalers are fit on the same training column, so they end up identical.
# Two objects are kept so that input scaling and output un-scaling read separately.
# NOTE(review): the scalers are re-fit here instead of being loaded from the
# training run. Confirm they match the scaling the LSTM was trained with.
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))

X_train_scaled = scaler_X.fit_transform(_train_sessions)
y_train_scaled = scaler_y.fit_transform(_train_sessions)
X_test_scaled = scaler_X.transform(_test_sessions)
y_test_scaled = scaler_y.transform(_test_sessions)

# Insert a length-1 "timesteps" axis: (n_test, 1) -> (n_test, 1, 1).
# NOTE(review): the inputs are the *actual* test values, so this is a
# one-step-ahead prediction given true history. It is not directly comparable
# with SARIMA's multi-step forecast above.
X_test_lstm = X_test_scaled[:, np.newaxis, :]

lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
# Map the predictions back to session counts as a flat 1-D array.
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
# Combine predictions into a DataFrame for visualization. The frame takes
# test_data's index from the "Datetime" Series.
# np.asarray strips any index from the predictions. If sarima_model was fit on
# pandas data, forecast() may return a Series with its own index. pandas would
# align that index against test_data's instead of pairing rows positionally,
# which silently yields NaNs when the two differ.
future_predictions = pd.DataFrame({
    "Datetime": test_data["Datetime"],
    "SARIMA_Predicted": np.asarray(sarima_predictions),
    "LSTM_Predicted": np.asarray(lstm_predictions),
})
# Test-period error metrics for both models.
# RMSE is computed as sqrt(MSE) because the `squared=False` keyword of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
# float() keeps the values plain Python floats, as before.
mae_sarima_future = mean_absolute_error(test_data["Sessions"], sarima_predictions)
rmse_sarima_future = float(np.sqrt(mean_squared_error(test_data["Sessions"], sarima_predictions)))
mae_lstm_future = mean_absolute_error(test_data["Sessions"], lstm_predictions)
rmse_lstm_future = float(np.sqrt(mean_squared_error(test_data["Sessions"], lstm_predictions)))
# Per-model styling of the prediction line:
# (column in future_predictions, legend label, line colour).
_PREDICTION_LINES = {
    "SARIMA": ("SARIMA_Predicted", "SARIMA Predicted", "blue"),
    "LSTM": ("LSTM_Predicted", "LSTM Predicted", "green"),
}


def generate_plot(model):
    """Plot actual test-period traffic against one model's predictions.

    Args:
        model: Model name, "SARIMA" or "LSTM". Any other name produces a plot
            that contains only the actual-traffic line.

    Returns:
        Path of the PNG written: ``f"{model.lower()}_plot.png"``, relative to
        the current working directory.
    """
    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", and close the figure in `finally` so it is not leaked when
    # drawing or saving raises.
    # NOTE(review): the output file name depends only on the model, so two
    # simultaneous requests for the same model write the same file.
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(test_data["Datetime"], test_data["Sessions"],
                label="Actual Traffic", color="black",
                linestyle="dotted", linewidth=2)
        line_spec = _PREDICTION_LINES.get(model)
        if line_spec is not None:
            column, label, color = line_spec
            ax.plot(future_predictions["Datetime"], future_predictions[column],
                    label=label, color=color, linewidth=2)
        ax.set_title(f"{model} Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        plot_path = f"{model.lower()}_plot.png"
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path
# Function to display metrics for both models
def display_metrics():
    """Return the test-set error metrics as a table with one row per model.

    Rows are SARIMA then LSTM. Columns are "Model",
    "Mean Absolute Error (MAE)" and "Root Mean Squared Error (RMSE)".
    """
    rows = [
        ("SARIMA", mae_sarima_future, rmse_sarima_future),
        ("LSTM", mae_lstm_future, rmse_lstm_future),
    ]
    return pd.DataFrame(
        rows,
        columns=[
            "Model",
            "Mean Absolute Error (MAE)",
            "Root Mean Squared Error (RMSE)",
        ],
    )
# Gradio interface function
def dashboard_interface(model="SARIMA"):
    """Produce both dashboard outputs for the selected *model*.

    Returns:
        A ``(plot_path, metrics_text)`` tuple: the PNG path from
        ``generate_plot`` and the metrics table rendered as plain text.
    """
    image_file = generate_plot(model)
    metrics_text = display_metrics().to_string()
    return image_file, metrics_text
# Build the Gradio interface
with gr.Blocks() as dashboard:
    gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
    gr.Markdown("Use the dropdown menu to select a model and view its predictions vs actual traffic along with performance metrics.")

    # Input: which model's predictions to show.
    model_selection = gr.Dropdown(
        ["SARIMA", "LSTM"],
        label="Select Model",
        value="SARIMA",
    )

    # Outputs: the image receives the PNG path returned by
    # dashboard_interface; the textbox receives the metrics table as text.
    plot_output = gr.Image(label="Prediction Plot")
    metrics_output = gr.Textbox(label="Metrics", lines=15)

    # The outputs are only filled in when this button is clicked.
    update_button = gr.Button("Update Dashboard")
    update_button.click(
        fn=dashboard_interface,
        inputs=[model_selection],
        outputs=[plot_output, metrics_output],
    )

# Launch the Gradio dashboard
dashboard.launch()