sunbal7 committed on
Commit
bb2c8a3
·
verified ·
1 Parent(s): b2b5a52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -62
app.py CHANGED
@@ -4,16 +4,36 @@ import numpy as np
4
  import tensorflow as tf
5
  from sklearn.preprocessing import MinMaxScaler
6
  import plotly.express as px
7
- import os
8
  from groq import Groq
 
9
 
10
  # Initialize session state
11
  if 'model' not in st.session_state:
12
  st.session_state.model = None
13
  if 'threshold' not in st.session_state:
14
  st.session_state.threshold = None
 
 
15
 
16
- # Autoencoder model definition
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def build_autoencoder(input_dim):
18
  model = tf.keras.Sequential([
19
  tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
@@ -31,74 +51,96 @@ groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")
31
 
32
  # Main interface
33
  st.title("🛰️ AI Network Anomaly Detector")
34
- st.write("Upload your network data (CSV) to detect anomalies")
35
 
36
- # File uploader
37
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
38
-
39
- # Load or generate sample data
40
- if uploaded_file is not None:
41
- data = pd.read_csv(uploaded_file)
 
42
  else:
43
- st.info("Using sample data. Upload a file to use your own dataset.")
44
- data = pd.read_csv("sample_wifi_data.csv") # You should provide this sample file
45
 
46
  # Preprocessing
47
- features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
48
- scaler = MinMaxScaler()
49
- data_scaled = scaler.fit_transform(data[features])
 
50
 
51
- # Model training/fine-tuning
52
- if fine_tune or st.session_state.model is None:
53
  with st.spinner("Training model..."):
54
- autoencoder = build_autoencoder(data_scaled.shape[1])
55
- autoencoder.fit(data_scaled, data_scaled,
56
- epochs=100,
57
- batch_size=32,
58
- verbose=0,
59
- validation_split=0.1)
60
- st.session_state.model = autoencoder
61
-
62
- # Calculate threshold
63
- reconstructions = autoencoder.predict(data_scaled)
64
- mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
65
- st.session_state.threshold = np.percentile(mse, 95)
 
 
 
 
66
 
67
  # Anomaly detection
68
  if st.session_state.model and st.button("Detect Anomalies"):
69
- reconstructions = st.session_state.model.predict(data_scaled)
70
- mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
71
- anomalies = mse > st.session_state.threshold
72
-
73
- # Visualization
74
- fig = px.line(data, x=data.index, y='connection_attempts',
75
- title='Network Traffic with Anomalies')
76
- fig.add_scatter(x=data[anomalies].index, y=data[anomalies]['connection_attempts'],
77
- mode='markers', name='Anomalies')
78
- st.plotly_chart(fig)
79
-
80
- # Generate alert with Groq/Llama3
81
- if groq_api_key:
82
- try:
83
- client = Groq(api_key=groq_api_key)
84
- response = client.chat.completions.create(
85
- model="llama3-70b-8192",
86
- messages=[{
87
- "role": "user",
88
- "content": f"Generate a network security alert for {sum(anomalies)} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
89
- }]
90
- )
91
- st.warning(response.choices[0].message.content)
92
- except Exception as e:
93
- st.error(f"Groq API Error: {str(e)}")
94
- else:
95
- st.warning(f"Detected {sum(anomalies)} anomalies! Consider adding Groq API key for detailed analysis.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Download button for results
98
- if st.session_state.threshold:
99
- st.download_button(
100
- label="Download Anomaly Report",
101
- data=data[anomalies].to_csv().encode('utf-8'),
102
- file_name='anomalies_report.csv',
103
- mime='text/csv'
104
- )
 
4
  import tensorflow as tf
5
  from sklearn.preprocessing import MinMaxScaler
6
  import plotly.express as px
 
7
  from groq import Groq
8
+ import io
9
 
10
# Make sure every session-state slot read further down exists before first
# use; each one starts out as None until the app assigns it.
for _slot in ('model', 'threshold', 'data'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
17
 
18
+ # Generate synthetic data
19
def generate_sample_data(time_steps=500, n_anomalies=10, seed=None):
    """Build a synthetic hourly network-metrics table with injected anomalies.

    Args:
        time_steps: Number of hourly rows to generate, starting at
            2024-01-01 00:00.
        n_anomalies: Number of distinct rows to turn into anomalies. Values
            larger than ``time_steps`` are clamped to ``time_steps``.
        seed: Optional seed for the random generator. ``None`` (the default)
            gives fresh random data on every call; an integer makes the
            output reproducible.

    Returns:
        A ``pandas.DataFrame`` with the columns ``timestamp``,
        ``device_count``, ``connection_attempts``, ``packet_loss`` and
        ``latency``. On the anomalous rows ``connection_attempts`` is
        multiplied by 10 and ``latency`` by 5.
    """
    rng = np.random.default_rng(seed)

    data = pd.DataFrame({
        # An explicit Timedelta is used for the step instead of a string
        # frequency alias, because the 'H' alias is deprecated in recent
        # pandas releases.
        'timestamp': pd.date_range(start='2024-01-01', periods=time_steps,
                                   freq=pd.Timedelta(hours=1)),
        'device_count': rng.poisson(50, time_steps),
        'connection_attempts': rng.poisson(30, time_steps),
        'packet_loss': rng.uniform(0.1, 2.0, time_steps),
        'latency': rng.uniform(10, 100, time_steps),
    })

    # Add anomalies: pick distinct rows and inflate two of their metrics.
    n_anomalies = min(n_anomalies, time_steps)
    anomaly_indices = rng.choice(time_steps, n_anomalies, replace=False)
    data.loc[anomaly_indices, 'connection_attempts'] *= 10
    data.loc[anomaly_indices, 'latency'] *= 5
    return data
35
+
36
+ # Autoencoder model
37
  def build_autoencoder(input_dim):
38
  model = tf.keras.Sequential([
39
  tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
 
51
 
52
# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload network data (CSV) or use synthetic data")

# Columns the autoencoder is trained and scored on.
features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']

# Remember which dataset is currently loaded. Streamlit re-executes this
# script on every widget interaction, so without this marker the synthetic
# data would be regenerated (with new random values) on each rerun and the
# model/threshold would no longer match the data being scored.
if 'data_source' not in st.session_state:
    st.session_state.data_source = None

# Data handling
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    upload_id = ("upload", uploaded_file.name, uploaded_file.size)
    if st.session_state.data_source != upload_id:
        try:
            st.session_state.data = pd.read_csv(uploaded_file)
        except Exception as e:
            # Keep whatever dataset was loaded before; only report the failure.
            st.error(f"Error reading file: {str(e)}")
        else:
            st.session_state.data_source = upload_id
            # A model fitted on a different dataset must not be reused.
            st.session_state.model = None
            st.session_state.threshold = None
else:
    if st.session_state.data_source != "synthetic" or st.session_state.data is None:
        st.session_state.data = generate_sample_data()
        st.session_state.data_source = "synthetic"
        st.session_state.model = None
        st.session_state.threshold = None
    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")

# Preprocessing
data = st.session_state.data
data_scaled = None  # stays None whenever the dataset cannot be used
if data is not None:
    missing = [col for col in features if col not in data.columns]
    if missing:
        st.error(f"Missing required columns: {', '.join(missing)}")
    elif data.empty:
        st.error("The dataset contains no rows.")
    else:
        try:
            scaler = MinMaxScaler()
            data_scaled = scaler.fit_transform(data[features])
        except (ValueError, TypeError) as e:
            # e.g. non-numeric values in one of the feature columns
            st.error(f"Preprocessing error: {str(e)}")

# Model training
# NOTE(review): `fine_tune` and `build_autoencoder` are defined earlier in the
# file, outside this section; `fine_tune` is presumably a sidebar toggle -
# confirm. While it is truthy the model is retrained on every rerun.
if data_scaled is not None and (st.session_state.model is None or fine_tune):
    with st.spinner("Training model..."):
        try:
            autoencoder = build_autoencoder(data_scaled.shape[1])
            autoencoder.fit(data_scaled, data_scaled,
                            epochs=100,
                            batch_size=32,
                            verbose=0,
                            validation_split=0.2)

            # Calculate threshold: 95th percentile of the per-row
            # reconstruction error on the training data.
            reconstructions = autoencoder.predict(data_scaled)
            mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)

            # Publish model and threshold together, only after both succeeded.
            st.session_state.model = autoencoder
            st.session_state.threshold = np.percentile(mse, 95)
            st.success("Model trained successfully!")
        except Exception as e:
            st.error(f"Training error: {str(e)}")

# Anomaly detection
if (data_scaled is not None
        and st.session_state.model is not None
        and st.button("Detect Anomalies")):
    try:
        reconstructions = st.session_state.model.predict(data_scaled)
        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
        anomalies = mse > st.session_state.threshold
        n_anomalies = int(anomalies.sum())
        flagged = data[anomalies]

        # Visualization: use the timestamp column when the data has one,
        # otherwise fall back to the row index so uploaded files without
        # timestamps can still be plotted.
        has_timestamp = 'timestamp' in data.columns
        fig = px.line(data,
                      x='timestamp' if has_timestamp else data.index,
                      y='connection_attempts',
                      title='Network Traffic with Anomalies')
        fig.add_scatter(x=flagged['timestamp'] if has_timestamp else flagged.index,
                        y=flagged['connection_attempts'],
                        mode='markers', name='Anomalies',
                        marker=dict(color='red', size=8))
        st.plotly_chart(fig)

        # Groq API integration
        if groq_api_key:
            try:
                prompt = "\n".join([
                    "Network anomaly report:",
                    f"- Total data points: {len(data)}",
                    f"- Anomalies detected: {n_anomalies}",
                    f"- Max connection attempts: {data['connection_attempts'].max()}",
                    f"- Average latency: {data['latency'].mean():.2f}ms",
                    "Provide a technical analysis and recommendations in bullet points.",
                ])
                client = Groq(api_key=groq_api_key)
                response = client.chat.completions.create(
                    model="llama3-70b-8192",
                    messages=[{"role": "user", "content": prompt}],
                )
                st.subheader("AI Analysis")
                st.write(response.choices[0].message.content)
            except Exception as e:
                st.error(f"Groq API Error: {str(e)}")
        else:
            st.warning("Groq API key not provided - using basic detection")

        # Download report
        csv = flagged.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download Anomaly Report",
            data=csv,
            file_name='anomaly_report.csv',
            mime='text/csv'
        )
    except Exception as e:
        st.error(f"Detection error: {str(e)}")

# Display raw data
if data is not None and st.checkbox("Show raw data"):
    st.write(data)