sunbal7 committed
Commit b2b5a52 · verified · 1 Parent(s): 7381427

Create app.py

Files changed (1)
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
+import streamlit as st
+import pandas as pd
+import numpy as np
+import tensorflow as tf
+from sklearn.preprocessing import MinMaxScaler
+import plotly.express as px
+import os
+from groq import Groq
+
+# Initialize session state
+if 'model' not in st.session_state:
+    st.session_state.model = None
+if 'threshold' not in st.session_state:
+    st.session_state.threshold = None
+
+# Autoencoder model definition
+def build_autoencoder(input_dim):
+    model = tf.keras.Sequential([
+        tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
+        tf.keras.layers.Dense(32, activation='relu'),
+        tf.keras.layers.Dense(64, activation='relu'),
+        tf.keras.layers.Dense(input_dim)
+    ])
+    model.compile(optimizer='adam', loss='mse')
+    return model
+
+# Sidebar controls
+st.sidebar.title("Configuration")
+fine_tune = st.sidebar.button("Fine-tune Model")
+groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")
+
+# Main interface
+st.title("🛰️ AI Network Anomaly Detector")
+st.write("Upload your network data (CSV) to detect anomalies")
+
+# File uploader
+uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
+
+# Load or generate sample data
+if uploaded_file is not None:
+    data = pd.read_csv(uploaded_file)
+else:
+    st.info("Using sample data. Upload a file to use your own dataset.")
+    data = pd.read_csv("sample_wifi_data.csv")  # You should provide this sample file
+
+# Preprocessing
+features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
+scaler = MinMaxScaler()
+data_scaled = scaler.fit_transform(data[features])
+
+# Model training/fine-tuning
+if fine_tune or st.session_state.model is None:
+    with st.spinner("Training model..."):
+        autoencoder = build_autoencoder(data_scaled.shape[1])
+        autoencoder.fit(data_scaled, data_scaled,
+                        epochs=100,
+                        batch_size=32,
+                        verbose=0,
+                        validation_split=0.1)
+        st.session_state.model = autoencoder
+
+        # Calculate threshold as the 95th percentile of the reconstruction error
+        reconstructions = autoencoder.predict(data_scaled)
+        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
+        st.session_state.threshold = np.percentile(mse, 95)
+
+# Anomaly detection
+if st.session_state.model and st.button("Detect Anomalies"):
+    reconstructions = st.session_state.model.predict(data_scaled)
+    mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
+    anomalies = mse > st.session_state.threshold
+
+    # Visualization
+    fig = px.line(data, x=data.index, y='connection_attempts',
+                  title='Network Traffic with Anomalies')
+    fig.add_scatter(x=data[anomalies].index, y=data[anomalies]['connection_attempts'],
+                    mode='markers', name='Anomalies')
+    st.plotly_chart(fig)
+
+    # Generate alert with Groq/Llama3
+    if groq_api_key:
+        try:
+            client = Groq(api_key=groq_api_key)
+            response = client.chat.completions.create(
+                model="llama3-70b-8192",
+                messages=[{
+                    "role": "user",
+                    "content": f"Generate a network security alert for {sum(anomalies)} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
+                }]
+            )
+            st.warning(response.choices[0].message.content)
+        except Exception as e:
+            st.error(f"Groq API Error: {str(e)}")
+    else:
+        st.warning(f"Detected {sum(anomalies)} anomalies! Consider adding a Groq API key for detailed analysis.")
+
+    # Download button for results (kept inside the detection block so `anomalies` is defined)
+    if st.session_state.threshold is not None:
+        st.download_button(
+            label="Download Anomaly Report",
+            data=data[anomalies].to_csv().encode('utf-8'),
+            file_name='anomalies_report.csv',
+            mime='text/csv'
+        )