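# AI Network Anomaly Detector (Streamlit app).
# Trains a small autoencoder on scaled network metrics and flags rows whose
# reconstruction error exceeds a percentile threshold. Run locally with:
#   streamlit run app.py   (assuming this file is saved as app.py)
# Assumed dependencies (taken from the imports below): streamlit, pandas,
# numpy, tensorflow, scikit-learn, plotly, groq.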
import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
from groq import Groq
# Initialize session state
if 'model' not in st.session_state:
    st.session_state.model = None
if 'threshold' not in st.session_state:
    st.session_state.threshold = None
if 'data' not in st.session_state:
    st.session_state.data = None
# Generate synthetic data
def generate_sample_data():
    time_steps = 500
    data = pd.DataFrame({
        'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='h'),
        'device_count': np.random.poisson(50, time_steps),
        'connection_attempts': np.random.poisson(30, time_steps),
        'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
        'latency': np.random.uniform(10, 100, time_steps)
    })
    # Inject anomalies: spike connection attempts and latency at 10 random rows
    anomaly_indices = np.random.choice(time_steps, 10, replace=False)
    data.loc[anomaly_indices, 'connection_attempts'] *= 10
    data.loc[anomaly_indices, 'latency'] *= 5
    return data
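# Note: with no CSV uploaded, this synthetic frame is regenerated on every
# Streamlit rerun. For a reproducible demo you could seed NumPy first,
# e.g. np.random.seed(42) (an optional tweak, not part of the original app).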
# Autoencoder model
def build_autoencoder(input_dim):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),  # bottleneck
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim)  # linear output reconstructs the scaled inputs
    ])
    model.compile(optimizer='adam', loss='mse')
    return model
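# How detection works: the autoencoder learns to reproduce normal traffic
# through the 32-unit bottleneck, so rows it reconstructs poorly (high mean
# squared error) are likely anomalies. A minimal scoring helper, shown here
# for clarity; it mirrors the inline computation used later in the script:
def reconstruction_mse(model, x):
    """Per-row mean squared reconstruction error; higher means more anomalous."""
    recon = model.predict(x, verbose=0)
    return np.mean(np.square(x - recon), axis=1)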
# Sidebar controls
st.sidebar.title("Configuration")
fine_tune = st.sidebar.button("Fine-tune Model")
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")
# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload network data (CSV) or use synthetic data")
# Data handling
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file:
    try:
        st.session_state.data = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
else:
    st.session_state.data = generate_sample_data()
    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")
# Preprocessing
if st.session_state.data is not None:
    features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
    # Guard uploaded CSVs: fail early if required columns are absent
    # ('timestamp' is needed for the plot further down)
    missing = [c for c in features + ['timestamp'] if c not in st.session_state.data.columns]
    if missing:
        st.error(f"Missing required columns: {', '.join(missing)}")
        st.stop()
    scaler = MinMaxScaler()
    data_scaled = scaler.fit_transform(st.session_state.data[features])
    # Model training
    if st.session_state.model is None or fine_tune:
        with st.spinner("Training model..."):
            try:
                autoencoder = build_autoencoder(data_scaled.shape[1])
                autoencoder.fit(data_scaled, data_scaled,
                                epochs=100,
                                batch_size=32,
                                verbose=0,
                                validation_split=0.2)
                st.session_state.model = autoencoder
                # Threshold: 95th percentile of training reconstruction error,
                # so roughly 5% of training rows would be flagged
                reconstructions = autoencoder.predict(data_scaled, verbose=0)
                mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
                st.session_state.threshold = np.percentile(mse, 95)
                st.success("Model trained successfully!")
            except Exception as e:
                st.error(f"Training error: {str(e)}")
    # Anomaly detection
    if st.session_state.model is not None and st.button("Detect Anomalies"):
        try:
            reconstructions = st.session_state.model.predict(data_scaled, verbose=0)
            mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
            anomalies = mse > st.session_state.threshold
            # Visualization: full series with anomalous points overlaid in red
            fig = px.line(st.session_state.data, x='timestamp', y='connection_attempts',
                          title='Network Traffic with Anomalies')
            fig.add_scatter(x=st.session_state.data[anomalies]['timestamp'],
                            y=st.session_state.data[anomalies]['connection_attempts'],
                            mode='markers', name='Anomalies',
                            marker=dict(color='red', size=8))
            st.plotly_chart(fig)
            # Groq API integration
            if groq_api_key:
                try:
                    client = Groq(api_key=groq_api_key)
                    # Model availability changes over time; if this model has
                    # been retired, swap in a current one from Groq's model list.
                    response = client.chat.completions.create(
                        model="llama3-70b-8192",
                        messages=[{
                            "role": "user",
                            "content": f"""Network anomaly report:
- Total data points: {len(st.session_state.data)}
- Anomalies detected: {int(anomalies.sum())}
- Max connection attempts: {st.session_state.data['connection_attempts'].max()}
- Average latency: {st.session_state.data['latency'].mean():.2f}ms
Provide a technical analysis and recommendations in bullet points."""
                        }]
                    )
                    st.subheader("AI Analysis")
                    st.write(response.choices[0].message.content)
                except Exception as e:
                    st.error(f"Groq API Error: {str(e)}")
            else:
                st.warning("Groq API key not provided - skipping AI analysis")
            # Download report: the original (unscaled) rows flagged as anomalous
            report = st.session_state.data[anomalies]
            csv = report.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download Anomaly Report",
                data=csv,
                file_name='anomaly_report.csv',
                mime='text/csv'
            )
        except Exception as e:
            st.error(f"Detection error: {str(e)}")
# Display raw data
if st.checkbox("Show raw data"):
    st.write(st.session_state.data)