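"""Streamlit network anomaly detector.

Trains a small autoencoder on scaled traffic features and flags points whose
reconstruction error exceeds the 95th percentile. Optionally summarizes the
results via the Groq API. Launch with `streamlit run <this_file>.py`
(substitute the actual filename).
"""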
import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
from groq import Groq

# Initialize session state (Streamlit reruns this script on every interaction;
# session_state is what keeps the trained model, threshold, and data between runs)
if 'model' not in st.session_state:
    st.session_state.model = None
if 'threshold' not in st.session_state:
    st.session_state.threshold = None
if 'data' not in st.session_state:
    st.session_state.data = None

# Generate synthetic data
def generate_sample_data():
    time_steps = 500
    data = pd.DataFrame({
        'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='h'),
        'device_count': np.random.poisson(50, time_steps),
        'connection_attempts': np.random.poisson(30, time_steps),
        'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
        'latency': np.random.uniform(10, 100, time_steps)
    })
    
    # Inject anomalies: spike connection attempts and latency at 10 random timestamps
    anomaly_indices = np.random.choice(time_steps, 10, replace=False)
    data.loc[anomaly_indices, 'connection_attempts'] *= 10
    data.loc[anomaly_indices, 'latency'] *= 5
    return data

# Autoencoder model: a 64-32-64 bottleneck that learns to reconstruct normal
# traffic; anomalous points reconstruct poorly, yielding high error
def build_autoencoder(input_dim):
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),  # bottleneck
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim)               # linear reconstruction
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Sidebar controls
st.sidebar.title("Configuration")
fine_tune = st.sidebar.button("Retrain Model")  # rebuilds and retrains the autoencoder from scratch
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")

# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload network data (CSV) or use synthetic data")

# Data handling
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file:
    try:
        st.session_state.data = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
else:
    st.session_state.data = generate_sample_data()
    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")

# Preprocessing
if st.session_state.data is not None:
    features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
    scaler = MinMaxScaler()
    data_scaled = scaler.fit_transform(st.session_state.data[features])
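    # Caveat (added note): the scaler is refit on every rerun, so if new data is
    # uploaded after training, click "Retrain Model" so the model matches the scaling.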

# Model training
if st.session_state.model is None or fine_tune:
    with st.spinner("Training model..."):
        try:
            autoencoder = build_autoencoder(data_scaled.shape[1])
            autoencoder.fit(data_scaled, data_scaled,
                          epochs=100,
                          batch_size=32,
                          verbose=0,
                          validation_split=0.2)
            st.session_state.model = autoencoder
            
            # Threshold: flag the ~5% of points with the highest reconstruction error
            reconstructions = autoencoder.predict(data_scaled)
            mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
            st.session_state.threshold = np.percentile(mse, 95)
            st.success("Model trained successfully!")
        except Exception as e:
            st.error(f"Training error: {str(e)}")

# Anomaly detection
if st.session_state.model and st.button("Detect Anomalies"):
    try:
        reconstructions = st.session_state.model.predict(data_scaled)
        mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
        anomalies = mse > st.session_state.threshold
        
        # Visualization
        fig = px.line(st.session_state.data, x='timestamp', y='connection_attempts',
                     title='Network Traffic with Anomalies')
        fig.add_scatter(x=st.session_state.data[anomalies]['timestamp'],
                       y=st.session_state.data[anomalies]['connection_attempts'],
                       mode='markers', name='Anomalies',
                       marker=dict(color='red', size=8))
        st.plotly_chart(fig)

        # Groq API integration
        if groq_api_key:
            try:
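                # Groq's hosted model list changes over time; if the API rejects
                # "llama3-70b-8192", swap in a currently available model id.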
                client = Groq(api_key=groq_api_key)
                response = client.chat.completions.create(
                    model="llama3-70b-8192",
                    messages=[{
                        "role": "user",
                        "content": f"""Network anomaly report:
                        - Total data points: {len(st.session_state.data)}
                        - Anomalies detected: {sum(anomalies)}
                        - Max connection attempts: {st.session_state.data['connection_attempts'].max()}
                        - Average latency: {st.session_state.data['latency'].mean():.2f}ms
                        Provide a technical analysis and recommendations in bullet points."""
                    }]
                )
                st.subheader("AI Analysis")
                st.write(response.choices[0].message.content)
            except Exception as e:
                st.error(f"Groq API Error: {str(e)}")
        else:
            st.warning("Groq API key not provided - skipping AI analysis")

        # Download report
        report = st.session_state.data[anomalies]
        csv = report.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download Anomaly Report",
            data=csv,
            file_name='anomaly_report.csv',
            mime='text/csv'
        )
    except Exception as e:
        st.error(f"Detection error: {str(e)}")

# Display raw data
if st.checkbox("Show raw data"):
    st.write(st.session_state.data)