sunbal7 committed
Commit 588f02b · verified · 1 Parent(s): bb2c8a3

Update app.py

Files changed (1):
  1. app.py +62 -104
app.py CHANGED
@@ -4,36 +4,16 @@ import numpy as np
 import tensorflow as tf
 from sklearn.preprocessing import MinMaxScaler
 import plotly.express as px
+import os
 from groq import Groq
-import io

 # Initialize session state
 if 'model' not in st.session_state:
     st.session_state.model = None
 if 'threshold' not in st.session_state:
     st.session_state.threshold = None
-if 'data' not in st.session_state:
-    st.session_state.data = None

-# Generate synthetic data
-def generate_sample_data():
-    time_steps = 500
-    base = np.arange(time_steps)
-    data = pd.DataFrame({
-        'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='H'),
-        'device_count': np.random.poisson(50, time_steps),
-        'connection_attempts': np.random.poisson(30, time_steps),
-        'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
-        'latency': np.random.uniform(10, 100, time_steps)
-    })
-
-    # Add anomalies
-    anomaly_indices = np.random.choice(time_steps, 10, replace=False)
-    data.loc[anomaly_indices, 'connection_attempts'] *= 10
-    data.loc[anomaly_indices, 'latency'] *= 5
-    return data
-
-# Autoencoder model
+# Autoencoder model definition
 def build_autoencoder(input_dim):
     model = tf.keras.Sequential([
         tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
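
Note: this commit deletes generate_sample_data(), and the new code (second hunk below) instead reads a sample_wifi_data.csv from disk. A minimal sketch for producing that file offline, reusing the same synthetic columns and injected anomalies as the deleted helper; the script name and output path are assumptions, not part of the commit:

# make_sample_data.py (hypothetical helper, not part of this commit)
import numpy as np
import pandas as pd

time_steps = 500
data = pd.DataFrame({
    'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='H'),
    'device_count': np.random.poisson(50, time_steps),
    'connection_attempts': np.random.poisson(30, time_steps),
    'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
    'latency': np.random.uniform(10, 100, time_steps)
})

# Spike a few rows so the detector has something to find, as the removed helper did
anomaly_indices = np.random.choice(time_steps, 10, replace=False)
data.loc[anomaly_indices, 'connection_attempts'] *= 10
data.loc[anomaly_indices, 'latency'] *= 5

data.to_csv("sample_wifi_data.csv", index=False)
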
@@ -51,96 +31,74 @@ groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")

 # Main interface
 st.title("🛰️ AI Network Anomaly Detector")
-st.write("Upload network data (CSV) or use synthetic data")
+st.write("Upload your network data (CSV) to detect anomalies")

-# Data handling
+# File uploader
 uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
-if uploaded_file:
-    try:
-        st.session_state.data = pd.read_csv(uploaded_file)
-    except Exception as e:
-        st.error(f"Error reading file: {str(e)}")
+
+# Load or generate sample data
+if uploaded_file is not None:
+    data = pd.read_csv(uploaded_file)
 else:
-    st.session_state.data = generate_sample_data()
-    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")
+    st.info("Using sample data. Upload a file to use your own dataset.")
+    data = pd.read_csv("sample_wifi_data.csv")  # You should provide this sample file

 # Preprocessing
-if st.session_state.data is not None:
-    features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
-    scaler = MinMaxScaler()
-    data_scaled = scaler.fit_transform(st.session_state.data[features])
+features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
+scaler = MinMaxScaler()
+data_scaled = scaler.fit_transform(data[features])

-    # Model training
-    if st.session_state.model is None or fine_tune:
-        with st.spinner("Training model..."):
-            try:
-                autoencoder = build_autoencoder(data_scaled.shape[1])
-                autoencoder.fit(data_scaled, data_scaled,
-                                epochs=100,
-                                batch_size=32,
-                                verbose=0,
-                                validation_split=0.2)
-                st.session_state.model = autoencoder
-
-                # Calculate threshold
-                reconstructions = autoencoder.predict(data_scaled)
-                mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
-                st.session_state.threshold = np.percentile(mse, 95)
-                st.success("Model trained successfully!")
-            except Exception as e:
-                st.error(f"Training error: {str(e)}")
+# Model training/fine-tuning
+if fine_tune or st.session_state.model is None:
+    with st.spinner("Training model..."):
+        autoencoder = build_autoencoder(data_scaled.shape[1])
+        autoencoder.fit(data_scaled, data_scaled,
+                        epochs=100,
+                        batch_size=32,
+                        verbose=0,
+                        validation_split=0.1)
+        st.session_state.model = autoencoder
+
+        # Calculate threshold
+        reconstructions = autoencoder.predict(data_scaled)
+        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
+        st.session_state.threshold = np.percentile(mse, 95)

 # Anomaly detection
 if st.session_state.model and st.button("Detect Anomalies"):
-    try:
-        reconstructions = st.session_state.model.predict(data_scaled)
-        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
-        anomalies = mse > st.session_state.threshold
-
-        # Visualization
-        fig = px.line(st.session_state.data, x='timestamp', y='connection_attempts',
-                      title='Network Traffic with Anomalies')
-        fig.add_scatter(x=st.session_state.data[anomalies]['timestamp'],
-                        y=st.session_state.data[anomalies]['connection_attempts'],
-                        mode='markers', name='Anomalies',
-                        marker=dict(color='red', size=8))
-        st.plotly_chart(fig)
-
-        # Groq API integration
-        if groq_api_key:
-            try:
-                client = Groq(api_key=groq_api_key)
-                response = client.chat.completions.create(
-                    model="llama3-70b-8192",
-                    messages=[{
-                        "role": "user",
-                        "content": f"""Network anomaly report:
-                        - Total data points: {len(st.session_state.data)}
-                        - Anomalies detected: {sum(anomalies)}
-                        - Max connection attempts: {st.session_state.data['connection_attempts'].max()}
-                        - Average latency: {st.session_state.data['latency'].mean():.2f}ms
-                        Provide a technical analysis and recommendations in bullet points."""
-                    }]
-                )
-                st.subheader("AI Analysis")
-                st.write(response.choices[0].message.content)
-            except Exception as e:
-                st.error(f"Groq API Error: {str(e)}")
-        else:
-            st.warning("Groq API key not provided - using basic detection")
-
-        # Download report
-        report = st.session_state.data[anomalies]
-        csv = report.to_csv(index=False).encode('utf-8')
-        st.download_button(
-            label="Download Anomaly Report",
-            data=csv,
-            file_name='anomaly_report.csv',
-            mime='text/csv'
-        )
-    except Exception as e:
-        st.error(f"Detection error: {str(e)}")
+    reconstructions = st.session_state.model.predict(data_scaled)
+    mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
+    anomalies = mse > st.session_state.threshold
+
+    # Visualization
+    fig = px.line(data, x=data.index, y='connection_attempts',
+                  title='Network Traffic with Anomalies')
+    fig.add_scatter(x=data[anomalies].index, y=data[anomalies]['connection_attempts'],
+                    mode='markers', name='Anomalies')
+    st.plotly_chart(fig)
+
+    # Generate alert with Groq/Llama3
+    if groq_api_key:
+        try:
+            client = Groq(api_key=groq_api_key)
+            response = client.chat.completions.create(
+                model="llama3-70b-8192",
+                messages=[{
+                    "role": "user",
+                    "content": f"Generate a network security alert for {sum(anomalies)} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
+                }]
+            )
+            st.warning(response.choices[0].message.content)
+        except Exception as e:
+            st.error(f"Groq API Error: {str(e)}")
+    else:
+        st.warning(f"Detected {sum(anomalies)} anomalies! Consider adding Groq API key for detailed analysis.")

-# Display raw data
-if st.checkbox("Show raw data"):
-    st.write(st.session_state.data)
+# Download button for results
+if st.session_state.threshold:
+    st.download_button(
+        label="Download Anomaly Report",
+        data=data[anomalies].to_csv().encode('utf-8'),
+        file_name='anomalies_report.csv',
+        mime='text/csv'
+    )
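
For reference, the detection scheme both versions share is reconstruction-error thresholding: train the autoencoder to reproduce the min-max-scaled features, score each row by its reconstruction MSE, and flag rows above the 95th-percentile error. A minimal standalone sketch of that pipeline outside Streamlit; the layers after the first Dense(64) are not visible in this diff, so the bottleneck shown here is an assumption:

import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler

def build_autoencoder(input_dim):
    # First layer matches app.py; the remaining layers are assumed for illustration
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Stand-in for the four scaled features (device_count, connection_attempts, packet_loss, latency)
X = np.random.rand(500, 4)
X_scaled = MinMaxScaler().fit_transform(X)

autoencoder = build_autoencoder(X_scaled.shape[1])
autoencoder.fit(X_scaled, X_scaled, epochs=10, batch_size=32, verbose=0)

reconstructions = autoencoder.predict(X_scaled, verbose=0)
mse = np.mean(np.square(X_scaled - reconstructions), axis=1)
threshold = np.percentile(mse, 95)   # top 5% of reconstruction errors are flagged
anomalies = mse > threshold
print(f"{anomalies.sum()} of {len(X_scaled)} rows flagged")

With a 95th-percentile cut-off, roughly 5% of rows are always flagged, so the threshold is relative to the data the model was trained on rather than an absolute alarm level.
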
 
 
 
 
 
 
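The Groq integration in the new version is a single chat completion against llama3-70b-8192, with the anomaly count and peak connection attempts interpolated into the prompt. A minimal sketch of that call in isolation, assuming the groq package is installed and a valid key is supplied; the prompt values here are placeholders:

from groq import Groq

client = Groq(api_key="YOUR_GROQ_API_KEY")  # app.py takes this from the sidebar text input
response = client.chat.completions.create(
    model="llama3-70b-8192",
    messages=[{
        "role": "user",
        "content": "Generate a network security alert for 12 anomalies detected. Max connection attempts: 412"
    }]
)
print(response.choices[0].message.content)
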
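One behaviour to be aware of in the new version: the download block at the bottom only checks st.session_state.threshold, but anomalies is defined inside the "Detect Anomalies" branch, so a rerun where the button is not pressed will raise a NameError at data[anomalies]. A sketch of one way to guard it by keeping the flags in session state; this is a suggestion, not part of the commit:

# Inside the "Detect Anomalies" branch, after computing the flags:
st.session_state.anomalies = anomalies

# Download section, guarded on the stored flags instead of the threshold alone:
if st.session_state.get('anomalies') is not None:
    st.download_button(
        label="Download Anomaly Report",
        data=data[st.session_state.anomalies].to_csv().encode('utf-8'),
        file_name='anomalies_report.csv',
        mime='text/csv'
    )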