manjunathainti commited on
Commit
f69d4dd
·
1 Parent(s): 873bd97

deployment issues

Browse files
Files changed (1) hide show
  1. app.py +44 -42
app.py CHANGED
@@ -28,11 +28,11 @@ lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
28
  scaler_X = MinMaxScaler(feature_range=(0, 1))
29
  scaler_y = MinMaxScaler(feature_range=(0, 1))
30
 
31
- # Fit scalers on the training data
32
  X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
33
  y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
34
 
35
- # Scale the test data
36
  X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
37
  y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
38
 
@@ -40,83 +40,85 @@ y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
40
  X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, 1))
41
 
42
  # Generate predictions for SARIMA
43
- sarima_predictions = sarima_model.predict(start=len(train_data), end=len(webtraffic_data) - 1)
 
44
 
45
  # Generate predictions for LSTM
46
- lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
47
- lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
48
 
49
  # Combine predictions into a DataFrame for visualization
50
  future_predictions = pd.DataFrame({
51
  "Datetime": test_data['Datetime'],
52
  "SARIMA_Predicted": sarima_predictions,
53
- "LSTM_Predicted": lstm_predictions
54
  })
55
 
56
  # Calculate metrics
57
- mae_sarima = mean_absolute_error(test_data['Sessions'], sarima_predictions)
58
- rmse_sarima = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
59
 
60
- mae_lstm = mean_absolute_error(test_data['Sessions'], lstm_predictions)
61
- rmse_lstm = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
62
 
63
- # Function to generate plots
64
- def generate_plot(model):
65
- """Generate plot based on the selected model."""
66
  plt.figure(figsize=(15, 6))
67
- plt.plot(test_data['Datetime'], test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
68
 
69
- if model == "SARIMA":
70
- plt.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'], label='SARIMA Predicted', color='blue', linewidth=2)
71
- elif model == "LSTM":
72
- plt.plot(future_predictions['Datetime'], future_predictions['LSTM_Predicted'], label='LSTM Predicted', color='green', linewidth=2)
73
 
74
- plt.title(f"{model} Predictions vs Actual Traffic", fontsize=16)
 
 
 
 
 
 
 
 
 
 
75
  plt.xlabel("Datetime", fontsize=12)
76
  plt.ylabel("Sessions", fontsize=12)
77
  plt.legend(loc="upper left")
78
  plt.grid(True)
79
  plt.tight_layout()
80
- plot_path = f"{model.lower()}_plot.png"
 
 
81
  plt.savefig(plot_path)
82
  plt.close()
83
  return plot_path
84
 
85
- # Function to display metrics
86
  def display_metrics():
87
- """Generate metrics for both models."""
88
  metrics = {
89
  "Model": ["SARIMA", "LSTM"],
90
- "Mean Absolute Error (MAE)": [mae_sarima, mae_lstm],
91
- "Root Mean Squared Error (RMSE)": [rmse_sarima, rmse_lstm]
92
  }
93
  return pd.DataFrame(metrics)
94
 
95
- # Gradio interface function
96
- def dashboard_interface(model="SARIMA"):
97
- """Generate plot and metrics for the selected model."""
98
- plot_path = generate_plot(model)
99
  metrics_df = display_metrics()
100
  return plot_path, metrics_df.to_string()
101
 
102
- # Build the Gradio dashboard
103
  with gr.Blocks() as dashboard:
104
  gr.Markdown("## Web Traffic Prediction Dashboard")
105
- gr.Markdown("Select a model to view its predictions and performance metrics.")
106
 
107
- # Dropdown for model selection
108
- model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
109
-
110
- # Outputs: Plot and Metrics
111
  plot_output = gr.Image(label="Prediction Plot")
112
- metrics_output = gr.Textbox(label="Metrics", lines=10)
113
-
114
- # Button to update dashboard
115
- gr.Button("Update Dashboard").click(
116
- fn=dashboard_interface,
117
- inputs=[model_selection],
118
- outputs=[plot_output, metrics_output]
119
- )
120
 
121
  # Launch the dashboard
122
  dashboard.launch()
 
28
  scaler_X = MinMaxScaler(feature_range=(0, 1))
29
  scaler_y = MinMaxScaler(feature_range=(0, 1))
30
 
31
+ # Scale training data
32
  X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
33
  y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
34
 
35
+ # Scale test data
36
  X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
37
  y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
38
 
 
40
  X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, 1))
41
 
42
  # Generate predictions for SARIMA
43
+ future_periods = len(test_data)
44
+ sarima_predictions = sarima_model.predict(n_periods=future_periods)
45
 
46
  # Generate predictions for LSTM
47
+ lstm_predictions_scaled = lstm_model.predict(X_test_lstm[:future_periods])
48
+ lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled)
49
 
50
  # Combine predictions into a DataFrame for visualization
51
  future_predictions = pd.DataFrame({
52
  "Datetime": test_data['Datetime'],
53
  "SARIMA_Predicted": sarima_predictions,
54
+ "LSTM_Predicted": lstm_predictions.flatten()
55
  })
56
 
57
  # Calculate metrics
58
+ mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
59
+ rmse_sarima_future = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
60
 
61
+ mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
62
+ rmse_lstm_future = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
63
 
64
+ # Function to plot actual vs. predicted traffic
65
+ def plot_predictions():
 
66
  plt.figure(figsize=(15, 6))
 
67
 
68
+ # Plot actual traffic
69
+ plt.plot(webtraffic_data['Datetime'].iloc[-future_periods:],
70
+ test_data['Sessions'].values[-future_periods:],
71
+ label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
72
 
73
+ # Plot SARIMA predictions
74
+ plt.plot(future_predictions['Datetime'],
75
+ future_predictions['SARIMA_Predicted'],
76
+ label='SARIMA Predicted', color='blue', linewidth=2)
77
+
78
+ # Plot LSTM predictions
79
+ plt.plot(future_predictions['Datetime'],
80
+ future_predictions['LSTM_Predicted'],
81
+ label='LSTM Predicted', color='green', linewidth=2)
82
+
83
+ plt.title("Future Traffic Predictions: SARIMA vs LSTM", fontsize=16)
84
  plt.xlabel("Datetime", fontsize=12)
85
  plt.ylabel("Sessions", fontsize=12)
86
  plt.legend(loc="upper left")
87
  plt.grid(True)
88
  plt.tight_layout()
89
+
90
+ # Save the plot to a file
91
+ plot_path = "/content/predictions_plot.png"
92
  plt.savefig(plot_path)
93
  plt.close()
94
  return plot_path
95
 
96
+ # Function to display prediction metrics
97
  def display_metrics():
 
98
  metrics = {
99
  "Model": ["SARIMA", "LSTM"],
100
+ "Mean Absolute Error (MAE)": [mae_sarima_future, mae_lstm_future],
101
+ "Root Mean Squared Error (RMSE)": [rmse_sarima_future, rmse_lstm_future]
102
  }
103
  return pd.DataFrame(metrics)
104
 
105
+ # Gradio function to display the dashboard
106
+ def gradio_dashboard():
107
+ plot_path = plot_predictions()
 
108
  metrics_df = display_metrics()
109
  return plot_path, metrics_df.to_string()
110
 
111
+ # Gradio interface
112
  with gr.Blocks() as dashboard:
113
  gr.Markdown("## Web Traffic Prediction Dashboard")
114
+ gr.Markdown("This dashboard compares predictions from SARIMA and LSTM models.")
115
 
116
+ # Show the plot
 
 
 
117
  plot_output = gr.Image(label="Prediction Plot")
118
+ metrics_output = gr.Textbox(label="Prediction Metrics", lines=15)
119
+
120
+ # Define the Gradio button and actions
121
+ gr.Button("Update Dashboard").click(gradio_dashboard, outputs=[plot_output, metrics_output])
 
 
 
 
122
 
123
  # Launch the dashboard
124
  dashboard.launch()