Spaces:
Sleeping
Sleeping
Commit
·
f69d4dd
1
Parent(s):
873bd97
deployment issues
Browse files
app.py
CHANGED
@@ -28,11 +28,11 @@ lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
|
|
28 |
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
29 |
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
30 |
|
31 |
-
#
|
32 |
X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
33 |
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
34 |
|
35 |
-
# Scale
|
36 |
X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
|
37 |
y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
38 |
|
@@ -40,83 +40,85 @@ y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
|
40 |
X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, 1))
|
41 |
|
42 |
# Generate predictions for SARIMA
|
43 |
-
|
|
|
44 |
|
45 |
# Generate predictions for LSTM
|
46 |
-
lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
|
47 |
-
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled)
|
48 |
|
49 |
# Combine predictions into a DataFrame for visualization
|
50 |
future_predictions = pd.DataFrame({
|
51 |
"Datetime": test_data['Datetime'],
|
52 |
"SARIMA_Predicted": sarima_predictions,
|
53 |
-
"LSTM_Predicted": lstm_predictions
|
54 |
})
|
55 |
|
56 |
# Calculate metrics
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
|
61 |
-
|
62 |
|
63 |
-
# Function to
|
64 |
-
def
|
65 |
-
"""Generate plot based on the selected model."""
|
66 |
plt.figure(figsize=(15, 6))
|
67 |
-
plt.plot(test_data['Datetime'], test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
plt.xlabel("Datetime", fontsize=12)
|
76 |
plt.ylabel("Sessions", fontsize=12)
|
77 |
plt.legend(loc="upper left")
|
78 |
plt.grid(True)
|
79 |
plt.tight_layout()
|
80 |
-
|
|
|
|
|
81 |
plt.savefig(plot_path)
|
82 |
plt.close()
|
83 |
return plot_path
|
84 |
|
85 |
-
# Function to display metrics
|
86 |
def display_metrics():
|
87 |
-
"""Generate metrics for both models."""
|
88 |
metrics = {
|
89 |
"Model": ["SARIMA", "LSTM"],
|
90 |
-
"Mean Absolute Error (MAE)": [
|
91 |
-
"Root Mean Squared Error (RMSE)": [
|
92 |
}
|
93 |
return pd.DataFrame(metrics)
|
94 |
|
95 |
-
# Gradio
|
96 |
-
def
|
97 |
-
|
98 |
-
plot_path = generate_plot(model)
|
99 |
metrics_df = display_metrics()
|
100 |
return plot_path, metrics_df.to_string()
|
101 |
|
102 |
-
#
|
103 |
with gr.Blocks() as dashboard:
|
104 |
gr.Markdown("## Web Traffic Prediction Dashboard")
|
105 |
-
gr.Markdown("
|
106 |
|
107 |
-
#
|
108 |
-
model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
|
109 |
-
|
110 |
-
# Outputs: Plot and Metrics
|
111 |
plot_output = gr.Image(label="Prediction Plot")
|
112 |
-
metrics_output = gr.Textbox(label="Metrics", lines=
|
113 |
-
|
114 |
-
#
|
115 |
-
gr.Button("Update Dashboard").click(
|
116 |
-
fn=dashboard_interface,
|
117 |
-
inputs=[model_selection],
|
118 |
-
outputs=[plot_output, metrics_output]
|
119 |
-
)
|
120 |
|
121 |
# Launch the dashboard
|
122 |
dashboard.launch()
|
|
|
28 |
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
29 |
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
30 |
|
31 |
+
# Scale training data
|
32 |
X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
33 |
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
34 |
|
35 |
+
# Scale test data
|
36 |
X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
|
37 |
y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
38 |
|
|
|
40 |
X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, 1))
|
41 |
|
42 |
# Generate predictions for SARIMA
|
43 |
+
future_periods = len(test_data)
|
44 |
+
sarima_predictions = sarima_model.predict(n_periods=future_periods)
|
45 |
|
46 |
# Generate predictions for LSTM
|
47 |
+
lstm_predictions_scaled = lstm_model.predict(X_test_lstm[:future_periods])
|
48 |
+
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled)
|
49 |
|
50 |
# Combine predictions into a DataFrame for visualization
|
51 |
future_predictions = pd.DataFrame({
|
52 |
"Datetime": test_data['Datetime'],
|
53 |
"SARIMA_Predicted": sarima_predictions,
|
54 |
+
"LSTM_Predicted": lstm_predictions.flatten()
|
55 |
})
|
56 |
|
57 |
# Calculate metrics
|
58 |
+
mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
|
59 |
+
rmse_sarima_future = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
|
60 |
|
61 |
+
mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
|
62 |
+
rmse_lstm_future = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
|
63 |
|
64 |
+
# Function to plot actual vs. predicted traffic
|
65 |
+
def plot_predictions():
|
|
|
66 |
plt.figure(figsize=(15, 6))
|
|
|
67 |
|
68 |
+
# Plot actual traffic
|
69 |
+
plt.plot(webtraffic_data['Datetime'].iloc[-future_periods:],
|
70 |
+
test_data['Sessions'].values[-future_periods:],
|
71 |
+
label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
|
72 |
|
73 |
+
# Plot SARIMA predictions
|
74 |
+
plt.plot(future_predictions['Datetime'],
|
75 |
+
future_predictions['SARIMA_Predicted'],
|
76 |
+
label='SARIMA Predicted', color='blue', linewidth=2)
|
77 |
+
|
78 |
+
# Plot LSTM predictions
|
79 |
+
plt.plot(future_predictions['Datetime'],
|
80 |
+
future_predictions['LSTM_Predicted'],
|
81 |
+
label='LSTM Predicted', color='green', linewidth=2)
|
82 |
+
|
83 |
+
plt.title("Future Traffic Predictions: SARIMA vs LSTM", fontsize=16)
|
84 |
plt.xlabel("Datetime", fontsize=12)
|
85 |
plt.ylabel("Sessions", fontsize=12)
|
86 |
plt.legend(loc="upper left")
|
87 |
plt.grid(True)
|
88 |
plt.tight_layout()
|
89 |
+
|
90 |
+
# Save the plot to a file
|
91 |
+
plot_path = "/content/predictions_plot.png"
|
92 |
plt.savefig(plot_path)
|
93 |
plt.close()
|
94 |
return plot_path
|
95 |
|
96 |
+
# Function to display prediction metrics
|
97 |
def display_metrics():
|
|
|
98 |
metrics = {
|
99 |
"Model": ["SARIMA", "LSTM"],
|
100 |
+
"Mean Absolute Error (MAE)": [mae_sarima_future, mae_lstm_future],
|
101 |
+
"Root Mean Squared Error (RMSE)": [rmse_sarima_future, rmse_lstm_future]
|
102 |
}
|
103 |
return pd.DataFrame(metrics)
|
104 |
|
105 |
+
# Gradio function to display the dashboard
|
106 |
+
def gradio_dashboard():
|
107 |
+
plot_path = plot_predictions()
|
|
|
108 |
metrics_df = display_metrics()
|
109 |
return plot_path, metrics_df.to_string()
|
110 |
|
111 |
+
# Gradio interface
|
112 |
with gr.Blocks() as dashboard:
|
113 |
gr.Markdown("## Web Traffic Prediction Dashboard")
|
114 |
+
gr.Markdown("This dashboard compares predictions from SARIMA and LSTM models.")
|
115 |
|
116 |
+
# Show the plot
|
|
|
|
|
|
|
117 |
plot_output = gr.Image(label="Prediction Plot")
|
118 |
+
metrics_output = gr.Textbox(label="Prediction Metrics", lines=15)
|
119 |
+
|
120 |
+
# Define the Gradio button and actions
|
121 |
+
gr.Button("Update Dashboard").click(gradio_dashboard, outputs=[plot_output, metrics_output])
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# Launch the dashboard
|
124 |
dashboard.launch()
|