manjunathainti committed on
Commit
f8d0f44
·
verified ·
1 Parent(s): f69d4dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -94
app.py CHANGED
@@ -1,124 +1,109 @@
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
4
- import numpy as np
5
- import tensorflow as tf
6
  import joblib
7
- from sklearn.metrics import mean_absolute_error, mean_squared_error
8
- from sklearn.preprocessing import MinMaxScaler
9
 
10
  # Load the dataset
11
- webtraffic_data = pd.read_csv("webtraffic.csv")
12
-
13
- # Convert 'Hour Index' to datetime
14
- start_date = pd.Timestamp("2024-01-01 00:00:00")
15
- webtraffic_data['Datetime'] = start_date + pd.to_timedelta(webtraffic_data['Hour Index'], unit='h')
16
- webtraffic_data.drop(columns=['Hour Index'], inplace=True)
17
-
18
- # Split the data into train/test
19
- train_size = int(len(webtraffic_data) * 0.8)
20
- train_data = webtraffic_data.iloc[:train_size]
21
- test_data = webtraffic_data.iloc[train_size:]
22
-
23
- # Load pre-trained models
24
- sarima_model = joblib.load("sarima_model.pkl") # SARIMA model
25
- lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
26
-
27
- # Initialize scalers and scale the data for LSTM
28
- scaler_X = MinMaxScaler(feature_range=(0, 1))
29
- scaler_y = MinMaxScaler(feature_range=(0, 1))
30
-
31
- # Scale training data
32
- X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
33
- y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
34
-
35
- # Scale test data
36
- X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
37
- y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
38
-
39
- # Reshape test data for LSTM (samples, time_steps, features)
40
- X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, 1))
41
-
42
- # Generate predictions for SARIMA
43
- future_periods = len(test_data)
44
- sarima_predictions = sarima_model.predict(n_periods=future_periods)
45
-
46
- # Generate predictions for LSTM
47
- lstm_predictions_scaled = lstm_model.predict(X_test_lstm[:future_periods])
48
- lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled)
49
-
50
- # Combine predictions into a DataFrame for visualization
51
- future_predictions = pd.DataFrame({
52
- "Datetime": test_data['Datetime'],
53
- "SARIMA_Predicted": sarima_predictions,
54
- "LSTM_Predicted": lstm_predictions.flatten()
55
- })
56
-
57
- # Calculate metrics
58
- mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
59
- rmse_sarima_future = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
60
-
61
- mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
62
- rmse_lstm_future = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
63
-
64
- # Function to plot actual vs. predicted traffic
65
- def plot_predictions():
66
  plt.figure(figsize=(15, 6))
67
-
68
- # Plot actual traffic
69
- plt.plot(webtraffic_data['Datetime'].iloc[-future_periods:],
70
- test_data['Sessions'].values[-future_periods:],
71
- label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
72
-
73
- # Plot SARIMA predictions
74
- plt.plot(future_predictions['Datetime'],
75
- future_predictions['SARIMA_Predicted'],
76
- label='SARIMA Predicted', color='blue', linewidth=2)
77
-
78
- # Plot LSTM predictions
79
- plt.plot(future_predictions['Datetime'],
80
- future_predictions['LSTM_Predicted'],
81
- label='LSTM Predicted', color='green', linewidth=2)
82
-
83
- plt.title("Future Traffic Predictions: SARIMA vs LSTM", fontsize=16)
84
  plt.xlabel("Datetime", fontsize=12)
85
  plt.ylabel("Sessions", fontsize=12)
86
  plt.legend(loc="upper left")
87
  plt.grid(True)
88
  plt.tight_layout()
89
 
90
- # Save the plot to a file
91
- plot_path = "/content/predictions_plot.png"
92
  plt.savefig(plot_path)
93
  plt.close()
94
  return plot_path
95
 
96
- # Function to display prediction metrics
 
97
  def display_metrics():
98
  metrics = {
99
- "Model": ["SARIMA", "LSTM"],
100
- "Mean Absolute Error (MAE)": [mae_sarima_future, mae_lstm_future],
101
- "Root Mean Squared Error (RMSE)": [rmse_sarima_future, rmse_lstm_future]
102
  }
103
  return pd.DataFrame(metrics)
104
 
105
- # Gradio function to display the dashboard
106
- def gradio_dashboard():
107
- plot_path = plot_predictions()
 
108
  metrics_df = display_metrics()
109
  return plot_path, metrics_df.to_string()
110
 
111
- # Gradio interface
 
112
  with gr.Blocks() as dashboard:
113
- gr.Markdown("## Web Traffic Prediction Dashboard")
114
- gr.Markdown("This dashboard compares predictions from SARIMA and LSTM models.")
 
 
115
 
116
- # Show the plot
117
  plot_output = gr.Image(label="Prediction Plot")
118
- metrics_output = gr.Textbox(label="Prediction Metrics", lines=15)
119
 
120
- # Define the Gradio button and actions
121
- gr.Button("Update Dashboard").click(gradio_dashboard, outputs=[plot_output, metrics_output])
 
 
 
122
 
123
- # Launch the dashboard
124
- dashboard.launch()
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
 
 
4
  import joblib
 
 
5
 
6
  # Load the dataset
7
data_file = "webtraffic.csv"
webtraffic_data = pd.read_csv(data_file)

# Make sure a real datetime column named 'Datetime' is available.
if "Datetime" in webtraffic_data.columns:
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
elif "Hour Index" in webtraffic_data.columns:
    print("Datetime column missing. Attempting to create from 'Hour Index'.")
    # 'Hour Index' is interpreted as a number of hours counted from this
    # fixed start date.
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
        webtraffic_data["Hour Index"], unit="h"
    )
else:
    # Same exception type as the plain column lookup would raise, but with a
    # message that names the file and both accepted columns.
    raise KeyError(
        f"{data_file} must contain either a 'Datetime' or an 'Hour Index' column"
    )

# Ensure 'Datetime' column is sorted so the last row is the latest observation
webtraffic_data.sort_values("Datetime", inplace=True)

# Load the SARIMA model
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ship a model file that comes from a trusted source.
sarima_model = joblib.load("sarima_model.pkl")

# Number of future periods requested from the model and drawn on the chart
future_periods = 48

# NOTE: hard-coded placeholder values, not computed from the model or the
# data. The dashboard displays them as-is.
mae_sarima_future = 100
rmse_sarima_future = 150
32
+
33
+
34
# Function to generate plot based on SARIMA model
def generate_plot():
    """Plot the recorded traffic together with the SARIMA forecast.

    Requests ``future_periods`` predictions from ``sarima_model``, dates them
    at one-hour steps after the last row of ``webtraffic_data`` and saves the
    chart as a PNG file.

    Returns:
        str: Path of the saved image, ``"sarima_prediction_plot.png"``.
    """
    last_timestamp = webtraffic_data["Datetime"].iloc[-1]
    # Timedelta arithmetic instead of date_range(freq="H"): the upper-case
    # "H" alias is deprecated in recent pandas releases.
    hourly_steps = pd.to_timedelta(list(range(1, future_periods + 1)), unit="h")
    future_dates = last_timestamp + hourly_steps

    sarima_predictions = sarima_model.predict(n_periods=future_periods)
    # Keep only the values so they pair positionally with future_dates,
    # whatever container the model returns.
    predicted_values = pd.Series(sarima_predictions).to_numpy()

    plot_path = "sarima_prediction_plot.png"

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", which is global state shared by every caller in the process.
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(
            webtraffic_data["Datetime"],
            webtraffic_data["Sessions"],
            label="Actual Traffic",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_dates,
            predicted_values,
            label="SARIMA Predicted",
            color="blue",
            linewidth=2,
        )

        ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        fig.savefig(plot_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return plot_path
72
 
73
+
74
# Function to display SARIMA metrics
def display_metrics():
    """Return a one-row DataFrame holding the SARIMA MAE and RMSE values.

    The numbers are read from the module-level ``mae_sarima_future`` and
    ``rmse_sarima_future`` variables.
    """
    return pd.DataFrame(
        {
            "Model": ["SARIMA"],
            "Mean Absolute Error (MAE)": [mae_sarima_future],
            "Root Mean Squared Error (RMSE)": [rmse_sarima_future],
        }
    )
82
 
83
+
84
# Gradio interface function
def dashboard_interface():
    """Produce both dashboard outputs for a single button click.

    Returns:
        tuple: ``(image_path, metrics_text)`` -- the path returned by
        ``generate_plot()`` and the metrics table rendered as plain text.
    """
    image_path = generate_plot()
    metrics_text = display_metrics().to_string()
    return image_path, metrics_text
89
 
90
+
91
# Build the Gradio interface
with gr.Blocks() as dashboard:
    # Heading and short description shown at the top of the page.
    gr.Markdown("## Interactive SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown(
        "This dashboard shows SARIMA model predictions vs actual traffic along with performance metrics."
    )

    # Output widgets filled in by dashboard_interface().
    plot_output = gr.Image(label="Prediction Plot")
    metrics_output = gr.Textbox(label="Metrics", lines=15)

    # The button takes no inputs; a click refreshes both outputs.
    generate_button = gr.Button("Generate Predictions")
    generate_button.click(
        fn=dashboard_interface,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )

# Launch the Gradio dashboard
if __name__ == "__main__":
    dashboard.launch()