sunbal7 commited on
Commit
d7bc36b
·
verified ·
1 Parent(s): 588f02b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -91
app.py CHANGED
@@ -1,104 +1,91 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- import tensorflow as tf
5
- from sklearn.preprocessing import MinMaxScaler
6
- import plotly.express as px
7
- import os
8
- from groq import Groq
9
 
10
- # Initialize session state
11
- if 'model' not in st.session_state:
12
- st.session_state.model = None
13
- if 'threshold' not in st.session_state:
14
- st.session_state.threshold = None
 
 
 
15
 
16
- # Autoencoder model definition
17
- def build_autoencoder(input_dim):
18
- model = tf.keras.Sequential([
19
- tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
20
- tf.keras.layers.Dense(32, activation='relu'),
21
- tf.keras.layers.Dense(64, activation='relu'),
22
- tf.keras.layers.Dense(input_dim)
23
- ])
24
- model.compile(optimizer='adam', loss='mse')
25
- return model
26
 
27
- # Sidebar controls
28
- st.sidebar.title("Configuration")
29
- fine_tune = st.sidebar.button("Fine-tune Model")
30
- groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")
31
 
32
- # Main interface
33
- st.title("🛰️ AI Network Anomaly Detector")
34
- st.write("Upload your network data (CSV) to detect anomalies")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # File uploader
37
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
 
 
38
 
39
- # Load or generate sample data
40
- if uploaded_file is not None:
41
- data = pd.read_csv(uploaded_file)
42
- else:
43
- st.info("Using sample data. Upload a file to use your own dataset.")
44
- data = pd.read_csv("sample_wifi_data.csv") # You should provide this sample file
45
 
46
- # Preprocessing
47
- features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
48
- scaler = MinMaxScaler()
49
- data_scaled = scaler.fit_transform(data[features])
 
50
 
51
- # Model training/fine-tuning
52
- if fine_tune or st.session_state.model is None:
53
- with st.spinner("Training model..."):
54
- autoencoder = build_autoencoder(data_scaled.shape[1])
55
- autoencoder.fit(data_scaled, data_scaled,
56
- epochs=100,
57
- batch_size=32,
58
- verbose=0,
59
- validation_split=0.1)
60
- st.session_state.model = autoencoder
61
-
62
- # Calculate threshold
63
- reconstructions = autoencoder.predict(data_scaled)
64
- mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
65
- st.session_state.threshold = np.percentile(mse, 95)
66
 
67
- # Anomaly detection
68
- if st.session_state.model and st.button("Detect Anomalies"):
69
- reconstructions = st.session_state.model.predict(data_scaled)
70
- mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
71
- anomalies = mse > st.session_state.threshold
72
-
73
- # Visualization
74
- fig = px.line(data, x=data.index, y='connection_attempts',
75
- title='Network Traffic with Anomalies')
76
- fig.add_scatter(x=data[anomalies].index, y=data[anomalies]['connection_attempts'],
77
- mode='markers', name='Anomalies')
78
- st.plotly_chart(fig)
79
-
80
- # Generate alert with Groq/Llama3
81
- if groq_api_key:
82
- try:
83
- client = Groq(api_key=groq_api_key)
84
- response = client.chat.completions.create(
85
- model="llama3-70b-8192",
86
- messages=[{
87
- "role": "user",
88
- "content": f"Generate a network security alert for {sum(anomalies)} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
89
- }]
90
- )
91
- st.warning(response.choices[0].message.content)
92
- except Exception as e:
93
- st.error(f"Groq API Error: {str(e)}")
94
- else:
95
- st.warning(f"Detected {sum(anomalies)} anomalies! Consider adding Groq API key for detailed analysis.")
96
 
97
- # Download button for results
98
- if st.session_state.threshold:
99
- st.download_button(
100
- label="Download Anomaly Report",
101
- data=data[anomalies].to_csv().encode('utf-8'),
102
- file_name='anomalies_report.csv',
103
- mime='text/csv'
104
- )
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from pyod.models.iforest import IForest
6
+ from pyod.models.lof import LOF
 
 
7
 
8
+ def main():
9
+ st.title("AI-Based Network Anomaly Detection (Predictive Maintenance)")
10
+ st.markdown(
11
+ """
12
+ This application uses AI to detect unusual behavior in a network before it leads to failure.
13
+ By leveraging open source models and PyOD, it predicts potential issues, enabling proactive maintenance.
14
+ """
15
+ )
16
 
17
+ # Sidebar settings for model and parameters
18
+ st.sidebar.header("Settings")
19
+ model_choice = st.sidebar.selectbox("Select Anomaly Detection Model", ("Isolation Forest", "Local Outlier Factor"))
20
+ contamination = st.sidebar.slider("Contamination (Expected anomaly ratio)", 0.0, 0.5, 0.1)
 
 
 
 
 
 
21
 
22
+ uploaded_file = st.file_uploader("Upload CSV file with network data", type=["csv"])
 
 
 
23
 
24
+ if uploaded_file is not None:
25
+ data = pd.read_csv(uploaded_file)
26
+ st.write("### Data Preview")
27
+ st.dataframe(data.head())
28
+ else:
29
+ st.info("No file uploaded. Generating synthetic network data for demonstration.")
30
+ # Generate synthetic data with features like traffic, latency, and packet_loss
31
+ np.random.seed(42)
32
+ n_samples = 300
33
+ traffic = np.random.normal(100, 10, n_samples)
34
+ latency = np.random.normal(50, 5, n_samples)
35
+ packet_loss = np.random.normal(0.5, 0.1, n_samples)
36
+ # Introduce anomalies by modifying a subset of data points
37
+ anomaly_indices = np.random.choice(n_samples, size=20, replace=False)
38
+ traffic[anomaly_indices] *= 1.5
39
+ latency[anomaly_indices] *= 2
40
+ packet_loss[anomaly_indices] *= 5
41
+
42
+ data = pd.DataFrame({
43
+ "traffic": traffic,
44
+ "latency": latency,
45
+ "packet_loss": packet_loss
46
+ })
47
+ st.write("### Synthetic Data")
48
+ st.dataframe(data.head())
49
 
50
+ # Use only numeric features for anomaly detection
51
+ features = data.select_dtypes(include=[np.number]).columns.tolist()
52
+ if not features:
53
+ st.error("No numeric columns found in the data for anomaly detection.")
54
+ return
55
 
56
+ X = data[features].values
 
 
 
 
 
57
 
58
+ # Initialize the selected model from PyOD
59
+ if model_choice == "Isolation Forest":
60
+ model = IForest(contamination=contamination)
61
+ elif model_choice == "Local Outlier Factor":
62
+ model = LOF(contamination=contamination)
63
 
64
+ # Fit the model and predict anomalies (0: normal, 1: anomaly)
65
+ model.fit(X)
66
+ predictions = model.labels_
67
+ data["anomaly"] = predictions
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ st.subheader("Anomaly Detection Results")
70
+ st.write(data.head())
71
+ n_anomalies = np.sum(predictions)
72
+ st.write(f"Detected **{n_anomalies}** anomalies out of **{len(data)}** data points.")
73
+
74
+ # Visualization (if at least 2 numeric features are available)
75
+ if len(features) >= 2:
76
+ st.subheader("Visualization")
77
+ fig, ax = plt.subplots()
78
+ # Plot using the first two numeric features
79
+ x_feature = features[0]
80
+ y_feature = features[1]
81
+ normal_data = data[data["anomaly"] == 0]
82
+ anomaly_data = data[data["anomaly"] == 1]
83
+ ax.scatter(normal_data[x_feature], normal_data[y_feature], label="Normal", color="blue", alpha=0.5)
84
+ ax.scatter(anomaly_data[x_feature], anomaly_data[y_feature], label="Anomaly", color="red", marker="x")
85
+ ax.set_xlabel(x_feature)
86
+ ax.set_ylabel(y_feature)
87
+ ax.legend()
88
+ st.pyplot(fig)
 
 
 
 
 
 
 
 
 
89
 
90
+ if __name__ == "__main__":
91
+ main()