import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
from groq import Groq

# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model = None
if 'threshold' not in st.session_state:
    st.session_state.threshold = None
if 'data' not in st.session_state:
    st.session_state.data = None
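# Note: Streamlit reruns this entire script on every widget interaction, so the
# trained model, detection threshold, and loaded data live in st.session_state
# to survive reruns instead of being rebuilt from scratch each time.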

# Generate synthetic data
def generate_sample_data():
    time_steps = 500
    data = pd.DataFrame({
        'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='h'),
        'device_count': np.random.poisson(50, time_steps),
        'connection_attempts': np.random.poisson(30, time_steps),
        'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
        'latency': np.random.uniform(10, 100, time_steps)
    })
    # Add anomalies: spike connection attempts and latency at random points
    anomaly_indices = np.random.choice(time_steps, 10, replace=False)
    data.loc[anomaly_indices, 'connection_attempts'] *= 10
    data.loc[anomaly_indices, 'latency'] *= 5
    return data
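
# Sanity check on the synthetic data: 10 injected spikes out of 500 points is
# ~2% contamination, so the 95th-percentile threshold used later should
# comfortably flag them. (A heuristic choice, not a tuned one.)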

# Autoencoder model
def build_autoencoder(input_dim):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model
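
# How this detects anomalies: the 4 input features are squeezed through a
# 32-unit bottleneck and reconstructed. Training on (mostly normal) traffic
# teaches the network to reconstruct typical patterns well, so unusual points
# come back with a high reconstruction error (MSE), which is scored below.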

# Sidebar controls
st.sidebar.title("Configuration")
fine_tune = st.sidebar.button("Fine-tune Model")
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")

# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload network data (CSV) or use synthetic data")

# Data handling
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file:
    try:
        st.session_state.data = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"Error reading file: {e}")
else:
    st.session_state.data = generate_sample_data()
    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")

# Preprocessing
if st.session_state.data is not None:
    features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
    # Uploaded CSVs must carry the same schema, or the steps below raise KeyError
    required = features + ['timestamp']
    missing = [col for col in required if col not in st.session_state.data.columns]
    if missing:
        st.error(f"Data is missing required columns: {', '.join(missing)}")
        st.stop()
    scaler = MinMaxScaler()
    data_scaled = scaler.fit_transform(st.session_state.data[features])
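
    # Caveat: the scaler is re-fit on the full dataset each run, anomalies
    # included, so extreme values compress the normal range slightly. Fitting
    # on a known-clean window would be stricter, but this matches the simple
    # flow above.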

    # Model training
    if st.session_state.model is None or fine_tune:
        with st.spinner("Training model..."):
            try:
                autoencoder = build_autoencoder(data_scaled.shape[1])
                autoencoder.fit(data_scaled, data_scaled,
                                epochs=100,
                                batch_size=32,
                                verbose=0,
                                validation_split=0.2)
                st.session_state.model = autoencoder

                # Calculate threshold from reconstruction error on the training data
                reconstructions = autoencoder.predict(data_scaled, verbose=0)
                mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
                st.session_state.threshold = np.percentile(mse, 95)
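                # The threshold is the 95th percentile of training-set MSE,
                # i.e. roughly the top 5% of reconstruction errors get flagged.
                # The percentile is a sensitivity knob, not a learned value.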
                st.success("Model trained successfully!")
            except Exception as e:
                st.error(f"Training error: {e}")

    # Anomaly detection
    if st.session_state.model and st.button("Detect Anomalies"):
        try:
            reconstructions = st.session_state.model.predict(data_scaled, verbose=0)
            mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
            anomalies = mse > st.session_state.threshold
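            # `anomalies` is a boolean array aligned row-for-row with the
            # DataFrame, so it doubles as a mask for plotting and the report.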

            # Visualization
            fig = px.line(st.session_state.data, x='timestamp', y='connection_attempts',
                          title='Network Traffic with Anomalies')
            fig.add_scatter(x=st.session_state.data.loc[anomalies, 'timestamp'],
                            y=st.session_state.data.loc[anomalies, 'connection_attempts'],
                            mode='markers', name='Anomalies',
                            marker=dict(color='red', size=8))
            st.plotly_chart(fig)

            # Groq API integration
            if groq_api_key:
                try:
                    client = Groq(api_key=groq_api_key)
                    response = client.chat.completions.create(
                        model="llama3-70b-8192",
                        messages=[{
                            "role": "user",
                            "content": f"""Network anomaly report:
- Total data points: {len(st.session_state.data)}
- Anomalies detected: {int(anomalies.sum())}
- Max connection attempts: {st.session_state.data['connection_attempts'].max()}
- Average latency: {st.session_state.data['latency'].mean():.2f} ms
Provide a technical analysis and recommendations in bullet points."""
                        }]
                    )
                    st.subheader("AI Analysis")
                    st.write(response.choices[0].message.content)
                except Exception as e:
                    st.error(f"Groq API Error: {e}")
            else:
                st.warning("Groq API key not provided - skipping AI analysis")

            # Download report
            report = st.session_state.data[anomalies]
            csv = report.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download Anomaly Report",
                data=csv,
                file_name='anomaly_report.csv',
                mime='text/csv'
            )
        except Exception as e:
            st.error(f"Detection error: {e}")

    # Display raw data
    if st.checkbox("Show raw data"):
        st.write(st.session_state.data)