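# Repository file-viewer header captured with the source (not part of the program):
#   uploader: sunbal7 | commit: "Create app.py" (b2b5a52, verified) | links: raw, history, blame | size: 3.85 kB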
import streamlit as st
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import plotly.express as px
import os
from groq import Groq
# Ensure the session-state slots read later in the script exist; each one
# starts out as None until a model has been trained.
for _slot in ('model', 'threshold'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None
# Autoencoder model definition
def build_autoencoder(input_dim: int):
    """Build and compile a small dense autoencoder.

    The network maps ``input_dim`` features through ReLU layers of 64, 32
    and 64 units and back to ``input_dim`` linear outputs. It is compiled
    with the Adam optimizer and mean-squared-error loss.

    Args:
        input_dim: Number of feature columns in each input sample.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        # Declare the input with an explicit Input layer rather than passing
        # `input_shape=` to the first Dense layer, which current Keras
        # releases discourage for Sequential models.
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(input_dim),
    ])
    model.compile(optimizer='adam', loss='mse')
    return model
# Sidebar: retraining trigger and the optional Groq credential.
with st.sidebar:
    st.title("Configuration")
    fine_tune = st.button("Fine-tune Model")
    groq_api_key = st.text_input("Groq API Key (optional)", type="password")

# Main page: heading, short instructions and the CSV upload widget.
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload your network data (CSV) to detect anomalies")
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
# Load the uploaded CSV, or fall back to the bundled sample file.
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
else:
    st.info("Using sample data. Upload a file to use your own dataset.")
    try:
        data = pd.read_csv("sample_wifi_data.csv")  # You should provide this sample file
    except FileNotFoundError:
        # Without the sample file there is nothing to analyse; show a
        # readable message instead of an unhandled traceback.
        st.error("Sample file 'sample_wifi_data.csv' was not found. Please upload a CSV file.")
        st.stop()

# Preprocessing
features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']

# Every feature column is required by the model; report all missing ones at
# once so the user can fix the file in a single pass.
missing_features = [name for name in features if name not in data.columns]
if missing_features:
    st.error(f"The dataset is missing required column(s): {', '.join(missing_features)}")
    st.stop()

# Rows with a missing feature value cannot be scored meaningfully, so they
# are excluded (with a visible notice) before scaling and training.
incomplete_rows = data[features].isna().any(axis=1)
if incomplete_rows.any():
    st.warning(f"Ignoring {int(incomplete_rows.sum())} row(s) with missing feature values.")
    data = data.loc[~incomplete_rows]

if data.empty:
    st.error("The dataset contains no usable rows.")
    st.stop()

# Scale each feature to the [0, 1] range expected by the autoencoder input.
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data[features])
# Model training / fine-tuning
# Runs on first load (no model stored yet) or when the sidebar button is
# pressed. Pressing the button continues training the stored model instead
# of discarding its weights, which is what "fine-tune" promises.
if fine_tune or st.session_state.model is None:
    with st.spinner("Training model..."):
        n_features = data_scaled.shape[1]
        autoencoder = st.session_state.model

        # Start from scratch when there is no model yet, or when the stored
        # model was built for a different number of input features.
        if autoencoder is None or autoencoder.input_shape[-1] != n_features:
            autoencoder = build_autoencoder(n_features)

        # The autoencoder learns to reproduce its own input.
        autoencoder.fit(
            data_scaled, data_scaled,
            epochs=100,
            batch_size=32,
            verbose=0,
            validation_split=0.1,
        )
        st.session_state.model = autoencoder

        # Calculate threshold: the 95th percentile of the per-row
        # reconstruction error on the training data.
        reconstructions = autoencoder.predict(data_scaled, verbose=0)
        mse = np.mean(np.square(data_scaled - reconstructions), axis=1)
        st.session_state.threshold = float(np.percentile(mse, 95))
# Anomaly detection
# Compare against None explicitly: a threshold of 0.0 is a valid value and
# must not be treated as "no threshold".
model_ready = (
    st.session_state.model is not None
    and st.session_state.threshold is not None
)
if model_ready and st.button("Detect Anomalies"):
    # Per-row reconstruction error; rows above the stored threshold are
    # flagged as anomalies.
    reconstructions = st.session_state.model.predict(data_scaled, verbose=0)
    mse = np.mean(np.square(data_scaled - reconstructions), axis=1)
    anomalies = mse > st.session_state.threshold
    anomaly_count = int(anomalies.sum())
    anomalous_rows = data[anomalies]

    # Visualization: connection attempts over the row index, with the
    # flagged rows overlaid as markers.
    fig = px.line(data, x=data.index, y='connection_attempts',
                  title='Network Traffic with Anomalies')
    fig.add_scatter(x=anomalous_rows.index,
                    y=anomalous_rows['connection_attempts'],
                    mode='markers', name='Anomalies')
    st.plotly_chart(fig)

    # Generate alert with Groq/Llama3 when a key was supplied; otherwise
    # show a plain summary.
    if groq_api_key:
        try:
            client = Groq(api_key=groq_api_key)
            response = client.chat.completions.create(
                model="llama3-70b-8192",
                messages=[{
                    "role": "user",
                    "content": f"Generate a network security alert for {anomaly_count} anomalies detected. Max connection attempts: {data['connection_attempts'].max()}"
                }]
            )
            st.warning(response.choices[0].message.content)
        except Exception as e:
            # UI boundary around a remote call: report the failure to the
            # user rather than aborting the rest of the page.
            st.error(f"Groq API Error: {e}")
    else:
        st.warning(f"Detected {anomaly_count} anomalies! Consider adding Groq API key for detailed analysis.")

    # Download button for results. It lives inside this branch because the
    # anomaly mask only exists once detection has run; referencing it
    # anywhere else would raise NameError.
    st.download_button(
        label="Download Anomaly Report",
        data=anomalous_rows.to_csv().encode('utf-8'),
        file_name='anomalies_report.csv',
        mime='text/csv'
    )