Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import tensorflow as tf | |
from sklearn.preprocessing import MinMaxScaler | |
import plotly.express as px | |
import os | |
from groq import Groq | |
# Initialize session state
# Each slot is created as None only when it is absent, so values stored by an
# earlier run of the script are left untouched.
for _slot in ('model', 'threshold'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
# Autoencoder model definition | |
def build_autoencoder(input_dim):
    """Build and compile a dense autoencoder for ``input_dim`` features.

    The network is a stack of four Dense layers with 64, 32, 64 and
    ``input_dim`` units. The first three use ReLU; the last one is created
    without an explicit activation. The model is compiled with the ``adam``
    optimizer and ``mse`` loss before being returned.

    Args:
        input_dim: Number of input features (and of reconstructed outputs).

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    dense = tf.keras.layers.Dense

    autoencoder = tf.keras.Sequential()
    # Encoder: widen to 64 units, then squeeze into the 32-unit bottleneck.
    autoencoder.add(dense(64, activation='relu', input_shape=(input_dim,)))
    autoencoder.add(dense(32, activation='relu'))
    # Decoder: expand again and project back to the input width.
    autoencoder.add(dense(64, activation='relu'))
    autoencoder.add(dense(input_dim))

    autoencoder.compile(optimizer='adam', loss='mse')
    return autoencoder
# Sidebar controls
st.sidebar.title("Configuration")
fine_tune = st.sidebar.button("Fine-tune Model")
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")

# Main interface
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload your network data (CSV) to detect anomalies")

# File uploader
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

# Load the uploaded CSV, or fall back to the sample file shipped next to the
# app. Read failures are shown in the UI and end this script run instead of
# surfacing as a raw traceback.
try:
    if uploaded_file is not None:
        data = pd.read_csv(uploaded_file)
    else:
        st.info("Using sample data. Upload a file to use your own dataset.")
        data = pd.read_csv("sample_wifi_data.csv")  # You should provide this sample file
except FileNotFoundError:
    st.error("Sample file 'sample_wifi_data.csv' was not found. "
             "Upload a CSV file to continue.")
    st.stop()
except (pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
    st.error(f"Could not read the CSV file: {exc}")
    st.stop()

# Preprocessing
features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']

# Every feature column must be present; otherwise data[features] raises KeyError.
missing_columns = [name for name in features if name not in data.columns]
if missing_columns:
    st.error("The dataset is missing required columns: "
             + ", ".join(missing_columns))
    st.stop()

# Rows with a missing feature value would feed NaN into training, so they are
# dropped. The original row labels are kept; later code selects rows by
# position, which stays aligned with data_scaled.
incomplete = data[features].isna().any(axis=1)
if incomplete.any():
    st.warning(f"Ignoring {int(incomplete.sum())} rows with missing feature values.")
    data = data.loc[~incomplete]

# MinMaxScaler cannot be fitted on zero samples.
if data.empty:
    st.error("The dataset contains no usable rows.")
    st.stop()

scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data[features])
# Model training/fine-tuning
# A new model is fitted when the sidebar button was pressed or when no model
# is cached in session state yet.
needs_training = st.session_state.model is None or fine_tune
if needs_training:
    with st.spinner("Training model..."):
        n_features = data_scaled.shape[1]
        autoencoder = build_autoencoder(n_features)

        # The autoencoder learns to reproduce its own input.
        fit_options = dict(
            epochs=100,
            batch_size=32,
            verbose=0,
            validation_split=0.1,
        )
        autoencoder.fit(data_scaled, data_scaled, **fit_options)
        st.session_state.model = autoencoder

        # Calculate threshold: 95th percentile of the per-row mean squared
        # reconstruction error on the data just used for fitting.
        rebuilt = autoencoder.predict(data_scaled)
        row_errors = np.square(data_scaled - rebuilt).mean(axis=1)
        st.session_state.threshold = np.percentile(row_errors, 95)
# Anomaly detection
# The button is offered only once both a model and its threshold exist.
# Explicit None checks are used so that a threshold of 0.0 still counts as
# set and the model object's truthiness is never consulted.
detector_ready = (
    st.session_state.model is not None
    and st.session_state.threshold is not None
)
if detector_ready and st.button("Detect Anomalies"):
    # Per-row mean squared reconstruction error, compared with the threshold.
    reconstructions = st.session_state.model.predict(data_scaled)
    mse = np.mean(np.power(data_scaled - reconstructions, 2), axis=1)
    anomalies = mse > st.session_state.threshold

    # Count and select the flagged rows once; both are reused below.
    anomaly_count = int(np.count_nonzero(anomalies))
    anomalous_rows = data[anomalies]

    # Visualization
    fig = px.line(data, x=data.index, y='connection_attempts',
                  title='Network Traffic with Anomalies')
    fig.add_scatter(x=anomalous_rows.index,
                    y=anomalous_rows['connection_attempts'],
                    mode='markers', name='Anomalies')
    st.plotly_chart(fig)

    # Generate alert with Groq/Llama3
    if groq_api_key:
        try:
            client = Groq(api_key=groq_api_key)
            response = client.chat.completions.create(
                model="llama3-70b-8192",
                messages=[{
                    "role": "user",
                    "content": f"Generate a network security alert for {anomaly_count} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
                }]
            )
        except Exception as e:
            # UI boundary: any client or network failure is reported in the
            # page rather than aborting the script run.
            st.error(f"Groq API Error: {str(e)}")
        else:
            st.warning(response.choices[0].message.content)
    else:
        st.warning(f"Detected {anomaly_count} anomalies! Consider adding Groq API key for detailed analysis.")

    # Download button for results
    # Kept inside this branch because the anomaly mask only exists after a
    # detection run; referencing it elsewhere would raise NameError.
    st.download_button(
        label="Download Anomaly Report",
        data=anomalous_rows.to_csv().encode('utf-8'),
        file_name='anomalies_report.csv',
        mime='text/csv'
    )