sunbal7's picture
Update app.py
bb2c8a3 verified
raw
history blame
5.64 kB
import hashlib
import io

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import tensorflow as tf
from groq import Groq
from sklearn.preprocessing import MinMaxScaler
# Initialize session state.
# Streamlit re-executes this script on every interaction, so each key is only
# seeded with None the first time; later runs keep whatever value is stored.
for _state_key in ('model', 'threshold', 'data'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Generate synthetic data
def generate_sample_data():
time_steps = 500
base = np.arange(time_steps)
data = pd.DataFrame({
'timestamp': pd.date_range(start='2024-01-01', periods=time_steps, freq='H'),
'device_count': np.random.poisson(50, time_steps),
'connection_attempts': np.random.poisson(30, time_steps),
'packet_loss': np.random.uniform(0.1, 2.0, time_steps),
'latency': np.random.uniform(10, 100, time_steps)
})
# Add anomalies
anomaly_indices = np.random.choice(time_steps, 10, replace=False)
data.loc[anomaly_indices, 'connection_attempts'] *= 10
data.loc[anomaly_indices, 'latency'] *= 5
return data
# Autoencoder model
def build_autoencoder(input_dim, hidden_dim=64, encoding_dim=32):
    """Build and compile a dense autoencoder that reconstructs its input.

    Layer widths: input_dim -> hidden_dim -> encoding_dim -> hidden_dim ->
    input_dim, with ReLU on the hidden layers and a linear output layer.
    Compiled with the Adam optimizer and mean-squared-error loss.

    Args:
        input_dim: Number of features per sample.
        hidden_dim: Width of the two outer hidden layers (default 64).
        encoding_dim: Width of the middle layer (default 32).

    Returns:
        A compiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        # Declare the input with tf.keras.Input; passing input_shape= to a
        # layer is deprecated in Keras 3 and emits a warning.
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(hidden_dim, activation='relu'),
        tf.keras.layers.Dense(encoding_dim, activation='relu'),
        tf.keras.layers.Dense(hidden_dim, activation='relu'),
        tf.keras.layers.Dense(input_dim),
    ])
    # NOTE(review): with the defaults the middle layer is wider than a small
    # input (this app passes 4 features), so the network can approach an
    # identity map, which weakens reconstruction-error anomaly scoring.
    # Consider an encoding_dim smaller than input_dim.
    model.compile(optimizer='adam', loss='mse')
    return model
# Sidebar controls: everything created inside this context renders in the
# sidebar, in the same order as before.
with st.sidebar:
    st.title("Configuration")
    fine_tune = st.button("Fine-tune Model")
    groq_api_key = st.text_input("Groq API Key (optional)", type="password")

# Main interface header
st.title("🛰️ AI Network Anomaly Detector")
st.write("Upload network data (CSV) or use synthetic data")
# Data handling
# Streamlit reruns this whole script on every widget interaction, so the
# active data source is remembered in session state:
# - synthetic data is generated once instead of on every rerun;
# - an upload is parsed only when its content changes;
# - any model/threshold fitted to the previous data is discarded, so training
#   happens again on the new data.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    raw_csv = uploaded_file.getvalue()
    source_key = ('upload', hashlib.sha256(raw_csv).hexdigest())
    if st.session_state.get('data_source') != source_key:
        try:
            # Parse from the raw bytes so the result does not depend on the
            # upload buffer's current read position.
            loaded = pd.read_csv(io.BytesIO(raw_csv))
        except Exception as e:
            # Clear both the data and the remembered source. Stale data is then
            # not analysed, a retry happens on the next rerun, and the synthetic
            # fallback is rebuilt if the upload is removed.
            st.session_state.data = None
            st.session_state.data_source = None
            st.error(f"Error reading file: {str(e)}")
        else:
            st.session_state.data = loaded
            st.session_state.data_source = source_key
            st.session_state.model = None
            st.session_state.threshold = None
else:
    if st.session_state.get('data_source') != ('synthetic',):
        st.session_state.data = generate_sample_data()
        st.session_state.data_source = ('synthetic',)
        st.session_state.model = None
        st.session_state.threshold = None
    st.info("Using synthetic data. Upload a CSV file to analyze your own data.")
# Preprocessing
def _reconstruction_errors(model, inputs):
    """Return each row's mean squared reconstruction error under ``model``.

    ``inputs`` is the 2-D scaled feature array; the result has one value per row.
    """
    # verbose=0 keeps Keras from printing a progress bar on every rerun.
    reconstructions = model.predict(inputs, verbose=0)
    return np.mean(np.square(inputs - reconstructions), axis=1)


# Scaled feature matrix for the current data. It stays None when the data is
# absent or unusable, in which case training and detection are skipped.
data_scaled = None
if st.session_state.data is not None:
    features = ['device_count', 'connection_attempts', 'packet_loss', 'latency']
    missing = [col for col in features if col not in st.session_state.data.columns]
    if missing:
        # Report a schema mismatch instead of letting an uncaught KeyError
        # crash the page.
        st.error(f"Missing required column(s): {', '.join(missing)}")
    else:
        scaler = MinMaxScaler()
        try:
            data_scaled = scaler.fit_transform(st.session_state.data[features])
        except ValueError as e:
            # e.g. non-numeric values or an empty table
            st.error(f"Could not scale feature columns: {str(e)}")
        else:
            # MinMaxScaler passes NaN through. Training on it would give NaN
            # losses and a NaN threshold that silently flags nothing.
            if np.isnan(data_scaled).any():
                st.error("Feature columns contain missing values; clean the data and re-upload.")
                data_scaled = None

if data_scaled is not None:
    # Model training: train when no model exists yet, or on explicit request.
    if st.session_state.model is None or fine_tune:
        with st.spinner("Training model..."):
            try:
                autoencoder = build_autoencoder(data_scaled.shape[1])
                autoencoder.fit(data_scaled, data_scaled,
                                epochs=100,
                                batch_size=32,
                                verbose=0,
                                validation_split=0.2)
                # Rows whose error exceeds the 95th percentile of the training
                # errors are treated as anomalous.
                threshold = np.percentile(_reconstruction_errors(autoencoder, data_scaled), 95)
                # Publish the model and threshold together, so a failure above
                # never leaves a model without a matching threshold.
                st.session_state.model = autoencoder
                st.session_state.threshold = threshold
                st.success("Model trained successfully!")
            except Exception as e:
                st.error(f"Training error: {str(e)}")

    # Anomaly detection
    if st.session_state.model is not None and st.button("Detect Anomalies"):
        try:
            mse = _reconstruction_errors(st.session_state.model, data_scaled)
            anomalies = mse > st.session_state.threshold
            n_anomalies = int(anomalies.sum())
            df = st.session_state.data
            flagged = df[anomalies]

            # Visualization: plot against the timestamp column when present,
            # otherwise against the row index.
            x_col = 'timestamp' if 'timestamp' in df.columns else None
            fig = px.line(df, x=x_col, y='connection_attempts',
                          title='Network Traffic with Anomalies')
            fig.add_scatter(x=flagged[x_col] if x_col else flagged.index,
                            y=flagged['connection_attempts'],
                            mode='markers', name='Anomalies',
                            marker=dict(color='red', size=8))
            st.plotly_chart(fig)

            # Groq API integration
            if groq_api_key:
                prompt = (
                    "Network anomaly report:\n"
                    f"- Total data points: {len(df)}\n"
                    f"- Anomalies detected: {n_anomalies}\n"
                    f"- Max connection attempts: {df['connection_attempts'].max()}\n"
                    f"- Average latency: {df['latency'].mean():.2f}ms\n"
                    "Provide a technical analysis and recommendations in bullet points."
                )
                try:
                    client = Groq(api_key=groq_api_key)
                    # NOTE(review): Groq retires model IDs over time; confirm
                    # "llama3-70b-8192" is still available before relying on it.
                    response = client.chat.completions.create(
                        model="llama3-70b-8192",
                        messages=[{"role": "user", "content": prompt}],
                    )
                    st.subheader("AI Analysis")
                    st.write(response.choices[0].message.content)
                except Exception as e:
                    st.error(f"Groq API Error: {str(e)}")
            else:
                st.warning("Groq API key not provided - using basic detection")

            # Download report containing only the flagged rows.
            csv = flagged.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download Anomaly Report",
                data=csv,
                file_name='anomaly_report.csv',
                mime='text/csv'
            )
        except Exception as e:
            st.error(f"Detection error: {str(e)}")
# Display raw data on request. When nothing is loaded (e.g. the upload failed
# to parse), show a notice instead of letting st.write render a bare "None".
show_raw = st.checkbox("Show raw data")
if show_raw:
    if st.session_state.data is None:
        st.info("No data loaded.")
    else:
        st.write(st.session_state.data)