File size: 3,846 Bytes
b2b5a52
 
 
 
 
 
588f02b
b2b5a52
 
 
 
 
 
 
 
588f02b
b2b5a52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588f02b
b2b5a52
588f02b
b2b5a52
588f02b
 
 
 
b2b5a52
588f02b
 
b2b5a52
 
588f02b
 
 
b2b5a52
588f02b
 
b2b5a52
588f02b
 
 
 
 
 
 
 
 
 
 
 
b2b5a52
 
 
588f02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2b5a52
588f02b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
import os
from groq import Groq

# Seed the per-session slots on the first run so later code can read them
# unconditionally; values already stored by an earlier rerun are left alone.
for _slot in ('model', 'threshold'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None

# Autoencoder model definition
def build_autoencoder(input_dim):
    """Build and compile a small fully connected autoencoder.

    The layer stack is Dense(64, relu) -> Dense(32, relu) -> Dense(64, relu)
    -> Dense(input_dim) with no activation specified on the output layer.
    The model is compiled with the 'adam' optimizer and 'mse' loss.

    Args:
        input_dim: Number of input features; also the width of the output
            layer, so the network reconstructs its own input.

    Returns:
        The compiled ``tf.keras.Sequential`` model (untrained).
    """
    dense = tf.keras.layers.Dense
    stack = [
        dense(64, activation='relu', input_shape=(input_dim,)),
        dense(32, activation='relu'),  # bottleneck
        dense(64, activation='relu'),
        dense(input_dim),
    ]
    autoencoder = tf.keras.Sequential(stack)
    autoencoder.compile(optimizer='adam', loss='mse')
    return autoencoder

# Sidebar: retraining trigger and the optional Groq credential.
with st.sidebar:
    st.title("Configuration")
    # True only on the rerun triggered by clicking the button.
    fine_tune = st.button("Fine-tune Model")
    groq_api_key = st.text_input("Groq API Key (optional)", type="password")

# Main page: heading, short instructions, then the CSV uploader.
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload your network data (CSV) to detect anomalies")

# None until the user has picked a file.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

# Load the uploaded dataset, or fall back to the bundled sample file.
# Every failure mode below is reported in the UI and halts this script run
# via st.stop() instead of surfacing a raw traceback to the user.
sample_path = "sample_wifi_data.csv"  # You should provide this sample file

# Columns fed to the autoencoder; every dataset must provide all of them.
features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']

if uploaded_file is not None:
    try:
        data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()
else:
    if not os.path.exists(sample_path):
        st.error(f"Sample file '{sample_path}' not found. Upload a CSV file to continue.")
        st.stop()
    st.info("Using sample data. Upload a file to use your own dataset.")
    data = pd.read_csv(sample_path)

# Validate the schema before indexing so a wrong file yields a clear message
# rather than a KeyError.
missing_columns = [name for name in features if name not in data.columns]
if missing_columns:
    st.error(f"Dataset is missing required column(s): {', '.join(missing_columns)}")
    st.stop()

if data.empty:
    st.error("Dataset contains no rows.")
    st.stop()

# Coerce to numbers so text cells become NaN and can be rejected together
# with genuinely missing values; MinMaxScaler keeps NaNs in its output, which
# would then reach model training.
feature_frame = data[features].apply(pd.to_numeric, errors="coerce")
if feature_frame.isna().any().any():
    st.error("Feature columns contain missing or non-numeric values. Clean the data and try again.")
    st.stop()

# Preprocessing: scale each feature to [0, 1]. The scaler is refit on every
# script run against the currently loaded data.
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(feature_frame)

# Train when no model is cached for this session, or when the sidebar button
# was clicked on this rerun.
# NOTE(review): the "Fine-tune" path builds a brand-new autoencoder rather
# than continuing to train the cached one - confirm this is intended.
needs_training = fine_tune or st.session_state.model is None
if needs_training:
    with st.spinner("Training model..."):
        feature_count = data_scaled.shape[1]
        net = build_autoencoder(feature_count)

        # Input and target are the same array: the network learns to
        # reconstruct its input.
        fit_options = dict(
            epochs=100,
            batch_size=32,
            verbose=0,
            validation_split=0.1,
        )
        net.fit(data_scaled, data_scaled, **fit_options)
        st.session_state.model = net

        # Threshold = 95th percentile of the per-row mean squared
        # reconstruction error on the data the model was just fitted to.
        rebuilt = net.predict(data_scaled)
        row_errors = np.mean(np.power(data_scaled - rebuilt, 2), axis=1)
        st.session_state.threshold = np.percentile(row_errors, 95)

# Anomaly detection and report download.
#
# The anomaly mask is computed on every run in which a trained model and a
# threshold exist, not only on the run where "Detect Anomalies" was clicked.
# The download button below needs the mask on every such run; computing it
# only inside the button branch left `anomalies` undefined (NameError)
# whenever the button had not just been pressed.
trained_model = st.session_state.model
threshold = st.session_state.threshold

# Explicit None checks: a computed threshold of 0.0 is falsy and would
# otherwise hide the report even though detection is possible.
if trained_model is not None and threshold is not None:
    reconstructions = trained_model.predict(data_scaled, verbose=0)
    # Per-row mean squared reconstruction error across the feature columns.
    mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
    anomalies = mse > threshold
    anomaly_count = int(anomalies.sum())

    if st.button("Detect Anomalies"):
        # Visualization: full series as a line, flagged rows as markers.
        fig = px.line(data, x=data.index, y='connection_attempts',
                      title='Network Traffic with Anomalies')
        flagged = data[anomalies]
        fig.add_scatter(x=flagged.index, y=flagged['connection_attempts'],
                        mode='markers', name='Anomalies')
        st.plotly_chart(fig)

        # Generate alert with Groq/Llama3 when a key was supplied; otherwise
        # fall back to a plain count.
        if groq_api_key:
            try:
                client = Groq(api_key=groq_api_key)
                response = client.chat.completions.create(
                    model="llama3-70b-8192",
                    messages=[{
                        "role": "user",
                        "content": f"Generate a network security alert for {anomaly_count} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
                    }]
                )
                st.warning(response.choices[0].message.content)
            except Exception as e:
                # UI boundary: show any client/network failure instead of
                # aborting the page.
                st.error(f"Groq API Error: {str(e)}")
        else:
            st.warning(f"Detected {anomaly_count} anomalies! Consider adding Groq API key for detailed analysis.")

    # Download button for results: the rows currently flagged as anomalous.
    st.download_button(
        label="Download Anomaly Report",
        data=data[anomalies].to_csv().encode('utf-8'),
        file_name='anomalies_report.csv',
        mime='text/csv'
    )